diff --git a/.gitignore b/.gitignore index 048167c..9f412a4 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ dist/ CLAUDE.md TODO_REAL_TESTS.md server.log +install_*.log .claude/ \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index a8fc1fd..04a89b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,12 +1,74 @@ # Changelog -## 1.1.1 — Pending +## 2.0.0-alpha.3 — 2025-09-08 -- Fix (Issue #27): Strict health completeness for multi-shard models in 1.x: - - Recognize both safetensors (`model.safetensors.index.json`) and PyTorch (`pytorch_model.bin.index.json`) JSON indices. - - Validate only the present format’s shards (exist, non-empty, not LFS pointers) to avoid false negatives. - - Aligns 1.x health behavior with 2.0.0-alpha.1 policy. - - Planned (Issue #31, under #29): Detect Framework/Type via HF Model Card (README front matter) and tokenizer config for non-`mlx-community` repos (lenient parsing). No CLI/JSON schema changes; focused unit tests; target 1.1.1-b2. +Port Issue #31 (lenient MLX detection) to 2.0; refine human list behavior. + +Hard split: 1.x code and tests have been removed from this branch to avoid confusion and license duality. Use the `main` branch for 1.x (MIT). + +### Added +- Detection helpers (README front‑matter + tokenizer): + - Framework=MLX when README front‑matter `tags` includes `mlx` or `library_name: mlx`, in addition to `mlx-community/*`. + - Type=chat when tokenizer has `chat_template`, or name hints (`instruct`/`chat`), or `config.model_type == 'chat'`. + - Unified `build_model_object(...)` used by `list` and `show` to ensure consistent fields. +- Tests: + - Offline: front‑matter and tokenizer detection for both `list` and `show`. + - Human output: verifies default/verbose/all filtering semantics. + - Live (opt-in): `tests_2.0/live/test_list_human_live.py` checks human list variants against a real HF cache (marker `-m live_list`). + - Push (offline): branch-missing tolerance and retry on "Invalid rev id" with `--create`. + +### Changed +- Human list (default): shows only MLX chat models (safer for run/server selection). +- Human list `--verbose`: shows all MLX models (chat + base). +- Human list `--all`: shows all frameworks (MLX, GGUF, PyTorch). +- `show` uses the same detection helpers as `list`; respects `HF_HOME` via `get_current_model_cache()`. + +### Docs +- SECURITY.md: clarified experimental push scope and network behavior (explicit only; no background traffic). +- README.md: added “Privacy & Network” bullet; updated version strings to alpha.3. + - README.md: noted hard split — 1.x lives on `main` (MIT), this branch is 2.x (Apache‑2.0). + +### Notes +- No JSON API schema changes; spec remains 0.1.3. + +### Fixed +- Push: tolerate missing target branches; with `--create`, proactively create the branch and retry the upload once. No‑op uploads still create the branch when `--create` is provided. + +## [1.1.1-beta.2] - 2025-09-06 + +### Feature: Lenient MLX Detection for Private Repos (Issue #31) +- Problem: `run` only accepted `mlx-community/*` models; private/cloned MLX repos (in MLX format) appeared as "PyTorch | base" and were rejected. +- Solution: Added README/tokenizer-based detection to recognize MLX/chat models outside `mlx-community`. +- Details: + - Tokenizer: If `tokenizer_config.json` contains a non-empty `chat_template` → Type = `chat` (highest priority). + - README front matter (YAML, lenient parse): + - `tags` contains `mlx` OR `library_name: mlx` → Framework = `MLX`. + - `pipeline_tag: text-generation` OR `tags` contain `chat`/`instruct` → Type = `chat`. + - `pipeline_tag: sentence-similarity` OR `tags` contain `embedding` → Type = `embedding`. + - Fallback unchanged: `.gguf` → `GGUF`; else `safetensors/bin` → `PyTorch`; else `Unknown`. Type fallback by name substrings (`instruct/chat` → chat; `embed` → embedding; else base). + +### CLI Behavior (Schema Unchanged) +- `mlxk show` now displays `Type: ` when detected. +- `mlxk list --all` includes a `TYPE` column; default `mlxk list` now shows chat-capable MLX models only (strict view). +- `mlxk run` now accepts MLX repos identified via README (not only `mlx-community/*`). + +### Implementation +- New helper: `mlx_knife/model_card.py` (no deps) to read README front matter and tokenizer hints; fully fail-safe. +- Updated detection in `mlx_knife/cache_utils.py`: + - `detect_framework(...)` consults README hints before file-type fallback. + - New `detect_model_type(...)` implements priority order. + - `run_model(...)` imports runner module for easier test monkeypatching. + +### Tests +- Added unit tests: `tests/unit/test_model_card_detection.py`. +- Server test stability and safety improvements: + - RAM-aware model gating now combines size-token heuristics with `mlxk show` data (disk size + quantization) for more reliable estimates. + - Fixed MoE size parsing (prefers tokens like `8x7B` over partial `7B` matches). + - Robust server process guard ensures clean shutdown on Ctrl-C/SIGTERM (prevents orphaned Python processes using excessive memory). + - Configurable safety/estimation factors via environment variables (see TESTING.md). +- All tests passing locally on Apple Silicon across Python 3.9–3.13: 166/166. + +Note: GitHub tag/version uses `1.1.1-beta.2`. PyPI release uses PEP 440 `1.1.1b2`. ## 2.0.0-alpha.2 — 2025-09-05 @@ -28,6 +90,23 @@ Experimental `push` (upload only) and documentation/testing refinements. ### Tests - Offline push tests added/extended, including dry-run planning; live push remains opt-in via `wet`/`live_push` markers and required env vars. +## [1.1.1-beta.1] - 2025-09-01 + +### Fix: Strict Health Completeness for Multi‑Shard Models (Issue #27) +- Problem: Health reported some multi‑part downloads as OK with missing/empty shards (false positives). +- Solution: Backported 2.0 health rules to 1.x with index‑aware validation, pattern detection, and robust corruption checks. +- Details: + - Config validation: `config.json` must exist and be a non‑empty JSON object. + - Index‑aware: If `model.safetensors.index.json` or `pytorch_model.bin.index.json` exists, every referenced shard must exist, be non‑empty, and not be a Git LFS pointer file. + - Pattern fallback policy: If pattern shards like `model-XXXXX-of-YYYYY.*` are present but no index file exists, the model is considered unhealthy (parity with 2.0 policy). + - Partial/tmp markers: Any `*.partial`, `*.tmp`, or names containing `partial` anywhere under the snapshot mark the model as unhealthy. + - LFS detection: Recursive scan flags suspiciously small files (<200B) that contain the Git LFS pointer header. + - Single‑file weights: Non‑empty `*.safetensors`, `*.bin`, or `*.gguf` without pattern shards remain supported and healthy if not LFS pointers. +- Impact: “Healthy” now reliably means “complete and usable” for automation and CLI workflows. +- Tests: Added `tests/unit/test_health_multishard.py` covering complete/missing/empty shards, pointer detection, pattern‑without‑index policy, partial markers, and PyTorch index parity. + +Note: GitHub tag/version uses `1.1.1-beta.1`. PyPI release uses PEP 440 `1.1.1b1`. + ## 2.0.0-alpha.1 — 2025-08-31 - New JSON-first CLI (`mlxk2`, `mlxk-json`); `--json` for machine-readable output (new vs 1.0.0). @@ -278,8 +357,4 @@ Experimental `push` (upload only) and documentation/testing refinements. ## Known Issues - See GitHub Issues for tracking -## 2.0.0-alpha.2 — 2025-09-04 -- Experimental: add `push` command (M0 upload-only) with hard excludes and `.hfignore` support -- Safety: require `--private` in CLI for alpha.2 to avoid accidental public uploads -- JSON: add `push` to schema; examples updated; short experimental disclaimer in responses -- Robustness: early validation for `pull` model names; improved CLI JSON errors for missing args + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d64ba03..78a4a10 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -81,14 +81,15 @@ Understanding what goes where: ``` Repository structure: -├── mlx_knife/ # Python package (→ PyPI) -├── tests/ # Test suite -├── simple_chat.html # Web interface (GitHub only) -├── README.md # User documentation -├── CONTRIBUTING.md # This file -├── TESTING.md # Testing guide -├── pyproject.toml # Build configuration -└── requirements.txt # Dependencies +├── mlxk2/ # 2.0 implementation (→ PyPI via mlxk-json) +├── tests_2.0/ # 2.0 test suite +├── docs/ # Documentation / ADRs +├── README.md # User documentation +├── CONTRIBUTING.md # This file +├── TESTING.md # Testing guide +├── pyproject.toml # Build configuration (dynamic version) +├── pyproject-mlxk-json.toml # Alternate build config (local/dev) +└── requirements.txt # Dev/test dependencies ``` **What goes where:** @@ -131,9 +132,9 @@ For detailed testing options, troubleshooting, and advanced workflows, see **[TE Please ensure all tests pass locally: ```bash # Complete test workflow -ruff check mlx_knife/ --fix # Fix code style -mypy mlx_knife/ # Check types -pytest tests/ # Run all tests +ruff check mlxk2/ --fix # Fix code style +mypy mlxk2/ # Check types +pytest -v # Run all 2.0 tests ``` Since we don't have CI/CD (MLX requires Apple Silicon), we rely on contributors to verify their changes locally. Please mention in your PR: @@ -170,8 +171,8 @@ Mention your Python version in the PR description. - Update documentation if needed 3. **Before submitting:** - - Run the full test suite locally: `pytest tests/` - - Run code quality checks: `ruff check mlx_knife/ --fix` + - Run the full test suite locally: `pytest -v` + - Run code quality checks: `ruff check mlxk2/ --fix` - Test with YOUR Python version (3.9+ required) - Update README.md if you've added features @@ -241,7 +242,10 @@ Feel free to open an issue with the "question" label or start a discussion. We'r ## License -By contributing, you agree that your contributions will be licensed under the MIT License. +- For 2.x (`mlxk2`, this branch): By contributing, you agree that your contributions will be licensed under the Apache License, Version 2.0. +- For 1.x (`main`): By contributing, you agree that your contributions will be licensed under the MIT License. + +Please ensure you have the right to contribute the code under these terms. We recommend including a Developer Certificate of Origin (DCO) “Signed-off-by” line in commits. --- diff --git a/LICENSE b/LICENSE index 2ec4020..1e32dfc 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,201 @@ -MIT License +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ -Copyright (c) 2025 The BROKE team 🦫 +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: +1. Definitions. -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. +"License" shall mean the terms and conditions for use, reproduction, +and distribution as defined by Sections 1 through 9 of this document. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +"Licensor" shall mean the copyright owner or entity authorized by +the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all +other entities that control, are controlled by, or are under common +control with that entity. For the purposes of this definition, +"control" means (i) the power, direct or indirect, to cause the +direction or management of such entity, whether by contract or +otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity +exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, +including but not limited to software source code, documentation +source, and configuration files. + +"Object" form shall mean any form resulting from mechanical +transformation or translation of a Source form, including but not +limited to compiled object code, generated documentation, and +conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or +Object form, made available under the License, as indicated by a +copyright notice that is included in or attached to the work +(an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object +form, that is based on (or derived from) the Work and for which the +editorial revisions, annotations, elaborations, or other modifications +represent, as a whole, an original work of authorship. For the purposes +of this License, Derivative Works shall not include works that remain +separable from, or merely link (or bind by name) to the interfaces of, +the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including +the original version of the Work and any modifications or additions +to that Work or Derivative Works thereof, that is intentionally +submitted to Licensor for inclusion in the Work by the copyright owner +or by an individual or Legal Entity authorized to submit on behalf of +the copyright owner. For the purposes of this definition, "submitted" +means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, +and issue tracking systems that are managed by, or on behalf of, the +Licensor for the purpose of discussing and improving the Work, but +excluding communication that is conspicuously marked or otherwise +designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity +on behalf of whom a Contribution has been received by Licensor and +subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the +Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of +this License, each Contributor hereby grants to You a perpetual, +worldwide, non-exclusive, no-charge, royalty-free, irrevocable +(except as stated in this section) patent license to make, have made, +use, offer to sell, sell, import, and otherwise transfer the Work, +where such license applies only to those patent claims licensable +by such Contributor that are necessarily infringed by their +Contribution(s) alone or by combination of their Contribution(s) +with the Work to which such Contribution(s) was submitted. If You +institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work +or a Contribution incorporated within the Work constitutes direct +or contributory patent infringement, then any patent licenses +granted to You under this License for that Work shall terminate +as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the +Work or Derivative Works thereof in any medium, with or without +modifications, and in Source or Object form, provided that You +meet the following conditions: + +(a) You must give any other recipients of the Work or +Derivative Works a copy of this License; and + +(b) You must cause any modified files to carry prominent notices +stating that You changed the files; and + +(c) You must retain, in the Source form of any Derivative Works +that You distribute, all copyright, patent, trademark, and +attribution notices from the Source form of the Work, +excluding those notices that do not pertain to any part of +the Derivative Works; and + +(d) If the Work includes a "NOTICE" text file as part of its +distribution, then any Derivative Works that You distribute must +include a readable copy of the attribution notices contained +within such NOTICE file, excluding those notices that do not +pertain to any part of the Derivative Works, in at least one +of the following places: within a NOTICE text file distributed +as part of the Derivative Works; within the Source form or +documentation, if provided along with the Derivative Works; or, +within a display generated by the Derivative Works, if and +wherever such third-party notices normally appear. The contents +of the NOTICE file are for informational purposes only and +do not modify the License. You may add Your own attribution +notices within Derivative Works that You distribute, alongside +or as an addendum to the NOTICE text from the Work, provided +that such additional attribution notices cannot be construed +as modifying the License. + +You may add Your own copyright statement to Your modifications and +may provide additional or different license terms and conditions +for use, reproduction, or distribution of Your modifications, or +for any such Derivative Works as a whole, provided Your use, +reproduction, and distribution of the Work otherwise complies with +the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, +any Contribution intentionally submitted for inclusion in the Work +by You to the Licensor shall be under the terms and conditions of +this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify +the terms of any separate license agreement you may have executed +with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade +names, trademarks, service marks, or product names of the Licensor, +except as required for reasonable and customary use in describing the +origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or +agreed to in writing, Licensor provides the Work (and each +Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied, including, without limitation, any warranties or conditions +of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A +PARTICULAR PURPOSE. You are solely responsible for determining the +appropriateness of using or redistributing the Work and assume any +risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, +whether in tort (including negligence), contract, or otherwise, +unless required by applicable law (such as deliberate and grossly +negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, +incidental, or consequential damages of any character arising as a +result of this License or out of the use or inability to use the +Work (including but not limited to damages for loss of goodwill, +work stoppage, computer failure or malfunction, or any and all +other commercial damages or losses), even if such Contributor +has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing +the Work or Derivative Works thereof, You may choose to offer, +and charge a fee for, acceptance of support, warranty, indemnity, +or other liability obligations and/or rights consistent with this +License. However, in accepting such obligations, You may act only +on Your own behalf and on Your sole responsibility, not on behalf +of any other Contributor, and only if You agree to indemnify, +defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason +of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following +boilerplate notice, with the fields enclosed by brackets "[]" +replaced with your own identifying information. (Don't include +the brackets!) The text should be enclosed in the appropriate +comment syntax for the file format. We also recommend that a +file or class name and description of purpose be included on the +same "printed page" as the copyright notice for easier +identification within third-party archives. + +Copyright [2025] [The BROKE team] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/README.md b/README.md index ad99c45..3580ead 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# BROKE Logo MLX-Knife 2.0.0-alpha.2 +# BROKE Logo MLX-Knife 2.0.0-alpha.3

MLX Knife Demo @@ -6,18 +6,18 @@ ## New: JSON-First Model Management for Automation & Scripting -> **🚧 Alpha Development:** Server and run are not included yet in 2.0.0-alpha.2. Use [MLX-Knife 1.1.0](https://github.com/mzau/mlx-knife/tree/main) for those features. +> **🚧 Alpha Development:** Server and run are not included yet in 2.0.0-alpha.3. Use [MLX-Knife 1.1.0](https://github.com/mzau/mlx-knife/tree/main) for those features. **Stable Version: 1.1.0** -[![GitHub Release](https://img.shields.io/badge/version-2.0.0--alpha.2-orange.svg)](https://github.com/mzau/mlx-knife/releases) +[![GitHub Release](https://img.shields.io/badge/version-2.0.0--alpha.3-orange.svg)](https://github.com/mzau/mlx-knife/releases) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/) [![Apple Silicon](https://img.shields.io/badge/Apple%20Silicon-M1%2FM2%2FM3-green.svg)](https://support.apple.com/en-us/HT211814) [![MLX](https://img.shields.io/badge/MLX-Latest-orange.svg)](https://github.com/ml-explore/mlx) [![Sponsor mlx-knife](https://img.shields.io/badge/Sponsor-mlx--knife-ff69b4?logo=github-sponsors&logoColor=white)](https://github.com/sponsors/mzau) -[![Tests](https://img.shields.io/badge/tests-45%2F45%20passing-brightgreen.svg)](#testing) +[![Tests](https://img.shields.io/badge/tests-98%2F98%20passing-brightgreen.svg)](#testing) ## Features @@ -25,9 +25,10 @@ - **List & Manage Models**: Browse your HuggingFace cache with MLX-specific filtering - **Model Information**: Detailed model metadata including quantization info - **Download Models**: Pull models from HuggingFace with progress tracking -- **Run Models**: Native MLX execution with streaming and chat modes (version 1.0.0 stable only) +- **Run Models**: Native MLX execution with streaming and chat modes (version 1.1.0 stable only) - **Health Checks**: Verify model integrity and completeness - **Cache Management**: Clean up and organize your model storage +- **Privacy & Network**: No background network or telemetry; only explicit Hugging Face interactions when you run pull or the experimental push. ### Requirements - macOS with Apple Silicon (M1/M2/M3) @@ -61,20 +62,24 @@ mlxk2 list --all --verbose mlxk2 health mlxk2 show "mlx-community/Phi-3-mini-4k-instruct-4bit" +### List filters (human) +- `list`: shows MLX chat models only (safe default for run/server selection) +- `list --verbose`: shows all MLX models (chat + base) +- `list --all`: shows all frameworks (MLX, GGUF, PyTorch) +- `list --all --verbose`: same selection as `--all`, with fuller names/details + +Note: JSON output is unaffected by these human-only filters. + ## JSON API mlxk2 list --json | jq '.data.models[].name' mlxk2 health --json | jq '.data.summary' mlxk2 show "Phi-3-mini" --json | jq '.data.model' ``` -## Differences vs 1.0.0 +## Compatibility Notes -- CLI: new entry points `mlxk2` and `mlxk-json` (1.0.0 used `mlxk`). -- Output: human output by default; add `--json` for machine-readable responses (new vs 1.0.0). -- List formatting: improved compact table with relative times in the Modified column (e.g., 3h ago) and a new Type column; compact MLX-only view by default. -- Flags (human-only): `--all` (all frameworks), `--health` (add Health column), `--verbose` (show full `org/model`). -- JSON API: current spec v0.1.3; CLI accepts `--json` after subcommands. -- Missing features (compared to 1.0.0): server and run are not included in 2.0 alpha.2 (use `mlxk` 1.x). +- 2.0 CLI is JSON-first with human output by default; use `--json` for API responses. +- Missing features vs 1.x: server and run are not included yet in 2.0 alpha.3 (use `mlxk` 1.x). ## ⚠️ Alpha Status Disclaimer @@ -142,8 +147,8 @@ This feature is not final and may change or be removed. pip install -e /path/to/mlx-knife # Verify installation -mlxk-json --version # → mlxk2 2.0.0-alpha.2 -mlxk2 --version # → mlxk2 2.0.0-alpha.2 +mlxk-json --version # → mlxk2 2.0.0-alpha.3 +mlxk2 --version # → mlxk2 2.0.0-alpha.3 ``` ### Parallel with MLX-Knife 1.x @@ -358,10 +363,10 @@ pytest tests/ -v # Current status: all current 2.0 tests pass (some optional schema tests may be skipped without extras) ``` -**Revolutionary Test Architecture:** +**Test Architecture:** - **Isolated Cache System** - Zero risk to user data - **Atomic Context Switching** - Production/test cache separation -- **Comprehensive Mock Models** - Realistic test scenarios +- **Mock Models** - Realistic test scenarios - **Edge Case Coverage** - All documented failure modes tested ## Known Issues & Limitations @@ -382,8 +387,8 @@ pytest tests/ -v ### Version Roadmap - **2.0.0-alpha** ← You are here (JSON API core complete) - **2.0.0-beta**: 6-8 weeks robust testing, production validation -- **2.0.0-rc**: Server/run features, full 1.x parity -- **2.0.0-stable**: Community validated, enterprise ready +- **2.0.0-rc**: Server/run features, full 1.x parity; CLI compatibility: `mlxk` alias alongside `mlxk2` +- **2.0.0-stable**: Stable release after RC feedback ### Architecture Decisions - **JSON-First**: All output structured for scripting and automation @@ -414,6 +419,14 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed guidelines. - **Discussions**: [GitHub Discussions](https://github.com/mzau/mlx-knife/discussions) - **API Specification**: [JSON API Specification](docs/json-api-specification.md) - **Documentation**: See `docs/` directory for technical details +- **Security Policy**: See [SECURITY.md](SECURITY.md) + +## License + +- 2.x (`mlxk2`, this branch): Apache License 2.0 — see `LICENSE` (root) and `mlxk2/NOTICE`. +- 1.x (`main` branch): MIT License — see `LICENSE` on `main`. + +Note: This branch is hard‑split for 2.0. The 1.x implementation and tests were removed here to avoid confusion and license duality; refer to the `main` branch for 1.x. **For production use**: Consider MLX-Knife 1.1.0 until 2.0.0-beta is available. @@ -425,7 +438,7 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed guidelines. --- -*MLX-Knife 2.0.0-alpha - Built for automation, tested for reliability, designed for the future.* +*MLX-Knife 2.0.0-alpha — JSON-first CLI for local model management.* ## Sponsors @@ -448,6 +461,6 @@ Special thanks to early supporters and users providing feedback during the 2.0 a

Made with ❤️ by The BROKE team BROKE Logo
- Version 2.0.0-alpha.2 | September 2025
+ Version 2.0.0-alpha.3 | September 2025
🔮 Next: BROKE Cluster for multi-node deployments

diff --git a/SECURITY.md b/SECURITY.md index f835c4d..a9e22f9 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -2,7 +2,7 @@ ## Overview -MLX Knife is designed to run locally on your Apple Silicon Mac. It prioritizes user privacy and security by keeping all model execution local. The only network activity is downloading models from HuggingFace (a trusted source). +MLX Knife is designed to run locally on your Apple Silicon Mac. It prioritizes user privacy and security by keeping all model execution local. Network activity is limited to explicit interactions with Hugging Face: downloading models (pull) and, in 2.0 alpha, an opt‑in experimental upload (push) when you run it explicitly. No background network traffic. ## Security Model @@ -11,13 +11,16 @@ MLX Knife is designed to run locally on your Apple Silicon Mac. It prioritizes u - ✅ Downloads models only from HuggingFace (trusted repository) - ✅ API server binds to localhost by default - ✅ No telemetry or usage tracking -- ✅ No external API calls (except HuggingFace for downloads) +- ✅ No external API calls (except explicit Hugging Face interactions: downloads via pull; optional upload via experimental push) +- ✅ Can upload a local workspace to Hugging Face only when you explicitly run `mlxk2 push` (experimental, opt‑in) ### What MLX Knife Doesn't Do -- ❌ No data is sent to external servers +- ❌ No data is sent to external servers automatically or in the background - ❌ No model outputs are logged or transmitted - ❌ No user tracking or analytics - ❌ No automatic updates or phone-home features + + Note: The experimental `push` command will upload files from a user‑selected local folder to Hugging Face only when you run it explicitly and provide credentials. It never runs implicitly. ## Reporting Security Vulnerabilities @@ -84,6 +87,36 @@ mlxk server --host 0.0.0.0 --port 8000 - Safe operations: reads (`list`, `health`, `show`) are always safe; coordinate writes (`pull`, `rm`) in maintenance windows - Test safeguards: the test suite places a sentinel in the test cache and enforces deletion guards to prevent accidental user-cache modification +### Experimental Push (`mlxk2 push`) + +The 2.0 alpha introduces an experimental upload capability. Treat it as opt‑in, with explicit user control. + +#### Scope and defaults +- Upload‑only (M0): pushes a specified local folder to a Hugging Face model repo via `huggingface_hub.upload_folder`. +- Requires `HF_TOKEN`; in alpha, `--private` is required to reduce accidental exposure. +- Default branch is `main` (overridable with `--branch`). No manifests or content validation yet. +- Honors default ignore patterns and merges project `.hfignore` when present (e.g., excludes `.git/`, `.venv/`, `__pycache__/`, `.DS_Store`). + +#### Privacy and boundaries +- Only files under the path you provide are considered; push does not scan your global caches or home directory. +- No prompts, logs, or runtime telemetry are uploaded. +- No background activity: nothing is sent unless you invoke `mlxk2 push`. + +#### Safety controls +- Preflight without network: `--check-only` analyzes the local folder for obvious issues (e.g., missing shards, LFS pointers). +- Plan without committing: `--dry-run` lists prospective adds/deletes vs remote (no upload performed). +- Use restricted tokens and test repos when validating; prefer `--private` and organization/user repos you control. + +#### Risks and mitigations +- Risk: Accidental upload of sensitive files included in the folder. + - Mitigate with a minimal, dedicated workspace, `.hfignore`, and `--check-only`/`--dry-run` before pushing. +- Risk: Pushing incomplete or corrupted weights. + - Mitigate by reviewing `workspace_health` from `--check-only` and model card requirements before uploading. + +#### Network and logging +- Network egress targets only Hugging Face over HTTPS; no third‑party endpoints. +- In `--json` mode, hub logs may be captured in output for diagnostics; they are not transmitted elsewhere by MLX Knife. + ## Security Best Practices ### For Users: diff --git a/TESTING.md b/TESTING.md index f054bef..69a017a 100644 --- a/TESTING.md +++ b/TESTING.md @@ -2,10 +2,10 @@ ## Current Status -✅ **150/150 tests passing** (August 2025) - **STABLE RELEASE** 🚀 +✅ **98/98 tests passing** (September 2025) — 2.0.0-alpha.3; 9 skipped (opt-in) ✅ **Apple Silicon verified** (M1/M2/M3) ✅ **Python 3.9-3.13 compatible** -✅ **Production ready** - comprehensive testing with real model execution +✅ **Alpha (CLI/JSON)** — default suite green locally (no inference) ✅ **Isolated test system** - user cache stays pristine with temp cache isolation ✅ **3-category test strategy** - optimized for performance and safety @@ -15,32 +15,34 @@ # Install package + tests pip install -e .[test] -# Download test model (optional - most tests use isolated cache) -mlxk pull mlx-community/Phi-3-mini-4k-instruct-4bit +# Download test model (optional; most 2.0 tests use isolated cache) +# Only needed for opt-in live tests or local experiments +# mlxk pull mlx-community/Phi-3-mini-4k-instruct-4bit -# Run 2.0 tests (default: tests_2.0/) +# Run 2.0 tests (default discovery: tests_2.0/) pytest -v -# Run legacy 1.x suite explicitly (not maintained here) -pytest tests/ -v - -# Fast unit tests only -pytest tests/unit/ +# Live tests (opt-in; not part of default): +# - Live push (requires env): +# export MLXK2_LIVE_PUSH=1 +# export HF_TOKEN=...; export MLXK2_LIVE_REPO=org/model; export MLXK2_LIVE_WORKSPACE=/abs/path +# pytest -q -m live_push +# - Live list (uses your HF_HOME; requires at least one MLX chat + one MLX base in cache): +# export HF_HOME=/path/to/huggingface/cache +# pytest -q -m live_list # Before committing -ruff check mlx_knife/ --fix && mypy mlx_knife/ && pytest +ruff check mlxk2/ --fix && mypy mlxk2/ && pytest -v ``` ## Why Local Testing? -MLX Knife requires **Apple Silicon hardware** and **real MLX models** for comprehensive testing: +MLX Knife tests fall into two categories for 2.0: -- **Hardware Requirement**: MLX framework only runs on Apple Silicon (M1/M2/M3) -- **Model Requirement**: Tests use actual models (4GB+) for realistic validation -- **Industry Standard**: Local testing is normal for MLX projects -- **Quality Assurance**: Real hardware testing ensures actual functionality +- CLI/JSON tests (default): Run on any supported Python on macOS; no model inference required; use an isolated HF cache (no network). +- Live/Inference tests (opt-in; future RC for server/run): Require Apple Silicon (M1/M2/M3) and real models. -This approach ensures our tests reflect real-world usage, not mocked behavior. +For push/list live tests in 2.0 alpha, see the opt-in commands above. ## Test Structure @@ -49,22 +51,34 @@ This approach ensures our tests reflect real-world usage, not mocked behavior. ``` tests_2.0/ ├── __init__.py -├── conftest.py # Isolated test cache, fixtures -├── test_edge_cases_adr002.py # Edge-case naming, ADR-002 -├── test_health_multifile.py # Multi-file health completeness -├── test_integration.py # Model resolution, health integration -├── test_issue_27.py # Health policy consistency -├── test_model_naming.py # Pattern/@hash parsing and resolution -├── test_robustness.py # General robustness tests -├── test_json_api_list.py # JSON API v0.1.2 (list contract) -├── test_json_api_show.py # JSON API v0.1.2 (show contract) -└── spec/ - ├── test_cli_version_output.py # version command JSON shape - ├── test_spec_doc_examples_validate.py # docs examples vs schema (jsonschema) - └── test_spec_version_sync.py # docs version == code constant +├── conftest.py # Isolated test cache, fixtures +├── test_human_output.py # Human rendering (list/health) +├── test_detection_readme_tokenizer.py # Issue #31 (README/tokenizer detection) +├── test_json_api_list.py # JSON API (list contract) +├── test_json_api_show.py # JSON API (show contract) +├── test_edge_cases_adr002.py # Edge-case naming, ADR-002 +├── test_health_multifile.py # Multi-file health completeness +├── test_integration.py # Model resolution, health integration +├── test_issue_27.py # Health policy consistency +├── test_model_naming.py # Pattern/@hash parsing and resolution +├── test_robustness.py # General robustness tests +├── test_cli_push_args.py # Push CLI args (offline) +├── test_push_minimal.py # Push minimal (offline) +├── test_push_extended.py # Push extended (offline) +├── test_push_dry_run.py # Push dry-run planning (offline) +├── test_push_workspace_check.py # Push check-only (offline) +├── spec/ +│ ├── test_cli_version_output.py # version command JSON shape +│ ├── test_spec_doc_examples_validate.py # docs examples vs schema +│ ├── test_spec_version_sync.py # docs version == code constant +│ ├── test_push_error_matches_schema.py # push error schema +│ └── test_push_output_matches_schema.py # push success schema +└── live/ # Opt-in live tests (markers) + ├── test_push_live.py # requires MLXK2_LIVE_PUSH, HF_TOKEN + └── test_list_human_live.py # requires HF_HOME ``` -Note: This tree is illustrative (not exhaustive). Push-related tests are documented in the dedicated "Push Testing (2.0)" section below to avoid drift. +Note: Live tests are opt-in via markers (`-m live_push`, `-m live_list`) and environment. Default `pytest` discovery runs only the offline suite above. ## Push Testing (2.0) @@ -76,7 +90,7 @@ This section summarizes what our test suite covers for the experimental `push` f - Args: - `--private` (required in alpha): Safety gate to avoid public uploads. - `--create`: Create the repository if it does not exist (model repo). - - `--branch`: Target branch, default `main`. +- `--branch`: Target branch, default `main`. Missing branches are tolerated; with `--create`, the branch is proactively created (and upload retried once if the hub initially rejects the revision). - `--commit`: Commit message, default `"mlx-knife push"`. - `--check-only`: Analyze workspace locally; no network call; returns `data.workspace_health`. - `--dry-run`: Compare local workspace to the remote branch and summarize changes without uploading (requires repo read access). @@ -118,6 +132,7 @@ Notes on output verbosity and behavior - Human mode is chatty by default: progress + one‑liner summary. `--verbose` appends the commit URL when present. - No‑changes detection: If the hub reports “No files have been modified… Skipping to prevent empty commit.”, JSON sets `no_changes: true`, `uploaded_files_count: 0`, and nulls `commit_sha`/`commit_url`. Human shows “— no changes”. This hub signal is preferred over inferring from file lists. - `--dry-run` human output: prints a concise plan line `dry-run: +A ~M -D` (modifications are an approximation and may be `~?` in rare cases). + - Branch creation with `--create`: Even if the push is a no‑op, the target branch is created upfront. Examples (expected) - No‑op re‑push (JSON): `commit_sha: null`, `commit_url: null`, `uploaded_files_count: 0`, `no_changes: true`, `message` mirrors hub text, `hf_logs` contains hub lines. @@ -198,18 +213,19 @@ Spec/Schema - **Schema shape:** Push success/error outputs validate against `docs/json-api-schema.json`. - **No-op push:** Detects `no_changes: true`, sets `uploaded_files_count: 0`, carries hub message into JSON (`message`/`hf_logs`), and human output shows "no changes" without duplicate logs. - **Commit path:** Extracts `commit_sha`, `commit_url`, `change_summary` (+/~/−), correct `uploaded_files_count`; human `--verbose` includes URL. -- **Repo/Branch handling:** Missing repo requires `--create`; with `--create` sets `created_repo: true`. Missing branch is tolerated; upload creates it. +- **Repo/Branch handling:** Missing repo requires `--create`; with `--create` sets `created_repo: true`. Missing branch is tolerated; upload attempts proceed. With `--create`, the branch is proactively created and the upload is retried once if the hub rejects the revision (e.g., “Invalid rev id”). - **Ignore rules:** `.hfignore` is merged with default ignores and forwarded to the hub. Files: - `tests_2.0/test_cli_push_args.py` (CLI errors and JSON outputs) -- `tests_2.0/test_push_extended.py` (no-op vs commit, branch/repo, .hfignore, human) +- `tests_2.0/test_push_extended.py` (no-op vs commit, branch/repo, .hfignore, human; includes retry on invalid revision with `--create`) - `tests_2.0/spec/test_push_output_matches_schema.py` (schema success path) Run (venv39): - `source venv39/bin/activate && pip install -e .` - `pytest -q tests_2.0/test_cli_push_args.py tests_2.0/test_push_extended.py` - `pytest -q tests_2.0/spec/test_push_output_matches_schema.py` +- Targeted retry test: `pytest -q tests_2.0/test_push_extended.py::test_push_retry_creates_branch_on_upload_revision_error` **Live (opt-in / wet)** - Purpose: sanity-check real HF behavior (auth, no-op vs commit, URLs). @@ -282,7 +298,7 @@ Notes - Not part of the 2.0 default run; execute explicitly with `pytest tests/ -v`. - Contains extensive integration/server tests unrelated to the 2.0 JSON CLI. -## 3-Category Test Strategy (MLX Knife 1.1.0+) +## Legacy 1.x: 3-Category Test Strategy (main) MLX Knife uses a **3-category test strategy** to balance test isolation, performance, and user cache protection: @@ -722,21 +738,20 @@ When submitting PRs, please include: - Python version - Which model(s) you tested with -2. **Test results summary**: - ``` - Platform: macOS 14.5, M2 Pro - Python: 3.11.6 - Model: Phi-3-mini-4k-instruct-4bit - Results: 150/150 tests passed - ``` +2. **Test results summary (2.0)**: + ``` + Platform: macOS 14.5, M2 Pro + Python: 3.11.6 + Results: 98/98 tests passed; 9 skipped (opt-in) + ``` 3. **Any issues encountered** and how you resolved them ## Summary -**MLX Knife 1.1.0 STABLE Testing Status:** +**Legacy 1.x Testing Status (main):** -✅ **Production Ready** - 150/150 tests passing +✅ **Stable** - 150/150 tests passing ✅ **Isolated Test System** - User cache stays pristine with temp cache isolation ✅ **3-Category Strategy** - Optimized for performance and safety ✅ **Multi-Python Support** - Python 3.9-3.13 verified @@ -748,11 +763,11 @@ When submitting PRs, please include: ✅ **LibreSSL Warning Fix** - Issue #22: macOS Python 3.9 warning suppression ✅ **Lock Cleanup Fix** - Issue #23: Enhanced rm command with lock cleanup -This comprehensive testing framework validates MLX Knife's **production readiness** through isolated testing with automatic model downloads and separate real MLX validation. +This testing framework validates MLX Knife's stability through isolated testing with automatic model downloads and separate real MLX validation. -## Server-Based Testing (Advanced) +## Server-Based Testing (Legacy 1.x; 2.0 RC planned) -Some tests require a running MLX Knife server with loaded models. These tests are marked with `@pytest.mark.server` and are **not run by default** with `pytest`. +In 1.x (main), some tests require a running MLX Knife server with loaded models and are marked with `@pytest.mark.server`. For 2.0, server/run will return in the RC; until then, server tests are legacy-only. ### Why Separate Server Tests? diff --git a/mlx_knife/__init__.py b/mlx_knife/__init__.py deleted file mode 100644 index c753355..0000000 --- a/mlx_knife/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -"""MLX Knife - HuggingFace-style cache management for MLX models. - -A lightweight, ollama-like CLI for managing and running MLX models on Apple Silicon. -Provides native MLX execution with streaming output and interactive chat capabilities. -""" - -# Suppress urllib3 LibreSSL warning on macOS system Python 3.9 (must be before any imports that use urllib3) -import warnings - -warnings.filterwarnings('ignore', message='urllib3 v2 only supports OpenSSL 1.1.1+') - -__version__ = "1.1.0" -__author__ = "The BROKE team" -__email__ = "broke@gmx.eu" -__license__ = "MIT" -__description__ = "ollama-style CLI for MLX models on Apple Silicon" -__url__ = "https://github.com/mzau/mlx-knife" - -# Version tuple for programmatic access (major, minor, patch) -VERSION = (1, 1, 0) - -# Core functionality imports -from .cache_utils import ( - check_all_models_health, - check_model_health, - list_models, - rm_model, - show_model, -) -from .hf_download import pull_model -from .mlx_runner import MLXRunner - -__all__ = [ - "__version__", - "list_models", - "show_model", - "check_model_health", - "check_all_models_health", - "rm_model", - "pull_model", - "MLXRunner", -] diff --git a/mlx_knife/cache_utils.py b/mlx_knife/cache_utils.py deleted file mode 100644 index 870dfd3..0000000 --- a/mlx_knife/cache_utils.py +++ /dev/null @@ -1,904 +0,0 @@ -# mlx_knife/cache_utils.py - -import datetime -import json -import os -import shutil -import sys -from pathlib import Path - -DEFAULT_CACHE_ROOT = Path.home() / ".cache/huggingface" -CACHE_ROOT = Path(os.environ.get("HF_HOME", DEFAULT_CACHE_ROOT)) -MODEL_CACHE = CACHE_ROOT / "hub" - -# Global variable to track if warning was shown -_legacy_warning_shown = False - -# Check for models in legacy location and warn user -_legacy_models = list(CACHE_ROOT.glob("models--*")) -_is_test_env = "test_cache" in str(CACHE_ROOT) or "PYTEST_CURRENT_TEST" in os.environ -if _legacy_models and not _legacy_warning_shown and not _is_test_env: - print(f"\n⚠️ Found {len(_legacy_models)} models in legacy location: {CACHE_ROOT}") - print(f" Please move them to: {MODEL_CACHE}") - print(f" Command: mv {CACHE_ROOT}/models--* {MODEL_CACHE}/") - print(" This warning will appear until models are moved.\n") - _legacy_warning_shown = True - - -def hf_to_cache_dir(hf_name: str) -> str: - if hf_name.startswith("models--"): - return hf_name - if "/" in hf_name: - org, model = hf_name.split("/", 1) - return f"models--{org}--{model}" - else: - return f"models--{hf_name}" - -def cache_dir_to_hf(cache_name: str) -> str: - if cache_name.startswith("models--"): - remaining = cache_name[len("models--"):] - if "--" in remaining: - parts = remaining.split("--", 1) - return f"{parts[0]}/{parts[1]}" - else: - return remaining - return cache_name - -def expand_model_name(model_name): - if "/" in model_name: - return model_name - mlx_candidate = f"mlx-community/{model_name}" - mlx_cache_dir = MODEL_CACHE / hf_to_cache_dir(mlx_candidate) - if mlx_cache_dir.exists(): - return mlx_candidate - common_mlx_patterns = [ - "Llama-", "Qwen", "Mistral", "Phi-", "Mixtral", "phi-", "deepseek" - ] - for pattern in common_mlx_patterns: - if pattern in model_name: - return f"mlx-community/{model_name}" - return model_name - -def find_matching_models(pattern): - """Find models that match a partial pattern. Returns a list of (model_dir, hf_name) tuples.""" - all_models = [d for d in MODEL_CACHE.iterdir() if d.name.startswith("models--")] - matches = [] - - for model_dir in all_models: - hf_name = cache_dir_to_hf(model_dir.name) - # Check if the pattern appears in the model name (case insensitive) - if pattern.lower() in hf_name.lower(): - matches.append((model_dir, hf_name)) - - return matches - -def hash_exists_in_local_cache(model_name, commit_hash): - """Check if a specific commit hash exists in the local cache for a model. - - Supports both full hashes and short hash prefixes (local resolution only). - - Args: - model_name: Full model name (e.g., 'mlx-community/Phi-3-mini-4k-instruct-4bit') - commit_hash: Commit hash to check for (short or full) - - Returns: - Full hash if exists in local cache, None otherwise - """ - base_cache_dir = MODEL_CACHE / hf_to_cache_dir(model_name) - if not base_cache_dir.exists(): - return None - - snapshots_dir = base_cache_dir / "snapshots" - if not snapshots_dir.exists(): - return None - - # Check for exact match first (full hash) - hash_dir = snapshots_dir / commit_hash - if hash_dir.exists(): - return commit_hash - - # Check for short hash match (local resolution) - if len(commit_hash) < 40: - for snapshot_dir in snapshots_dir.iterdir(): - if snapshot_dir.is_dir() and snapshot_dir.name.startswith(commit_hash): - return snapshot_dir.name # Return full hash - - return None - -def resolve_single_model(model_spec): - """ - Resolve a model spec to a single model, supporting fuzzy matching. - Returns (model_path, model_name, commit_hash) or (None, None, None) if failed. - Prints appropriate error messages for ambiguous matches. - """ - # Parse the model spec (handles @commit_hash syntax) - model_name, commit_hash = parse_model_spec(model_spec) - - # Try exact match first - base_cache_dir = MODEL_CACHE / hf_to_cache_dir(model_name) - if base_cache_dir.exists(): - return get_model_path(model_spec) - - # Extract the base name (without @commit_hash) for fuzzy matching - base_spec = model_spec.split('@')[0] if '@' in model_spec else model_spec - - # Try fuzzy matching - matches = find_matching_models(base_spec) - - if not matches: - print(f"No models found matching '{base_spec}'!") - return None, None, None - elif len(matches) == 1: - # Unambiguous match - use the found model with the original commit hash (if any) - found_model_dir, found_hf_name = matches[0] - if commit_hash: - resolved_spec = f"{found_hf_name}@{commit_hash}" - else: - resolved_spec = found_hf_name - return get_model_path(resolved_spec) - elif len(matches) > 1 and commit_hash: - # Issue #13: Hash-based disambiguation for ambiguous model names - for _model_dir, hf_name in matches: - resolved_hash = hash_exists_in_local_cache(hf_name, commit_hash) - if resolved_hash: - resolved_spec = f"{hf_name}@{resolved_hash}" - return get_model_path(resolved_spec) - - # Hash not found in any candidate model - print(f"Hash '{commit_hash}' not found in any model matching '{base_spec}'") - print("Available models:") - for _, hf_name in sorted(matches, key=lambda x: x[1]): - print(f" {hf_name}") - return None, None, None - else: - # Multiple matches without hash - show error with suggestions - print(f"Multiple models match '{base_spec}'. Please be more specific:") - for _, hf_name in sorted(matches, key=lambda x: x[1]): - print(f" {hf_name}") - return None, None, None - -def get_model_path(model_spec): - model_name, commit_hash = parse_model_spec(model_spec) - base_cache_dir = MODEL_CACHE / hf_to_cache_dir(model_name) - if not base_cache_dir.exists(): - return None, model_name, commit_hash - if commit_hash: - hash_dir = base_cache_dir / "snapshots" / commit_hash - if hash_dir.exists(): - return hash_dir, model_name, commit_hash - else: - return None, model_name, commit_hash - snapshots_dir = base_cache_dir / "snapshots" - if snapshots_dir.exists(): - snapshots = [d for d in snapshots_dir.iterdir() if d.is_dir()] - if snapshots: - latest = max(snapshots, key=lambda x: x.stat().st_mtime) - return latest, model_name, latest.name - # Return base_cache_dir for corrupted models so rm_model can handle them - return base_cache_dir, model_name, commit_hash - -def parse_model_spec(model_spec): - if "@" in model_spec: - model_name, commit_hash = model_spec.rsplit("@", 1) - model_name = expand_model_name(model_name) - return model_name, commit_hash - model_name = expand_model_name(model_spec) - return model_name, None - -def get_model_size(model_path): - if not model_path.exists(): - return "?" - total_size = 0 - for file in model_path.rglob("*"): - if file.is_file(): - total_size += file.stat().st_size - if total_size >= 1_000_000_000: - return f"{total_size / 1_000_000_000:.1f} GB" - elif total_size >= 1_000_000: - return f"{total_size / 1_000_000:.1f} MB" - else: - return f"{total_size / 1_000:.1f} KB" - -def get_model_modified(model_path): - if not model_path.exists(): - return "?" - mtime = model_path.stat().st_mtime - now = datetime.datetime.now() - modified = datetime.datetime.fromtimestamp(mtime) - diff = now - modified - if diff.days > 0: - return f"{diff.days} days ago" - elif diff.seconds > 3600: - hours = diff.seconds // 3600 - return f"{hours} hours ago" - else: - minutes = diff.seconds // 60 - return f"{minutes} minutes ago" - -def detect_framework(model_path, hf_name): - if "mlx-community" in hf_name: - return "MLX" - snapshots_dir = model_path / "snapshots" - if not snapshots_dir.exists(): - return "Unknown" - has_safetensors = any(snapshots_dir.glob("*/*.safetensors")) - has_pytorch_bin = any(snapshots_dir.glob("*/pytorch_model.bin")) - has_config = any(snapshots_dir.glob("*/config.json")) - total_size = get_model_size(model_path) - try: - size_mb = float(total_size.replace(" GB", "000").replace(" MB", "").replace(" KB", "0").replace(" ", "")) - except: - size_mb = 0 - if size_mb < 10: - return "Tokenizer" - elif has_safetensors and has_config: - return "PyTorch" - elif has_pytorch_bin: - return "PyTorch" - else: - return "Unknown" - -def get_model_hash(model_path): - snapshots_dir = model_path / "snapshots" - if not snapshots_dir.exists(): - return "--------" - snapshots = [d for d in snapshots_dir.iterdir() if d.is_dir()] - if not snapshots: - return "--------" - latest = max(snapshots, key=lambda x: x.stat().st_mtime) - return latest.name[:8] - -def is_model_healthy(model_spec): - model_path, _, _ = resolve_single_model(model_spec) - if not model_path: - return False - config_path = model_path / "config.json" - if not config_path.exists(): - return False - # Check if config.json is valid JSON and not empty - try: - with open(config_path) as f: - config_data = json.load(f) - # Basic sanity check: should be a non-empty dict - if not isinstance(config_data, dict) or len(config_data) == 0: - return False - except (OSError, json.JSONDecodeError): - return False - weight_files = list(model_path.glob("*.safetensors")) + list(model_path.glob("*.bin")) + list(model_path.glob("*.gguf")) - if not weight_files: - weight_files = list(model_path.glob("**/*.safetensors")) + list(model_path.glob("**/*.bin")) + list(model_path.glob("**/*.gguf")) - if not weight_files: - index_file = model_path / "model.safetensors.index.json" - if index_file.exists(): - try: - with open(index_file) as f: - index = json.load(f) - if 'weight_map' in index: - referenced_files = set(index['weight_map'].values()) - existing_files = [f for f in referenced_files if (model_path / f).exists()] - if len(existing_files) > 0: - return True - except: - pass - if not weight_files: - return False - lfs_ok, _ = check_lfs_corruption(model_path) - if not lfs_ok: - return False - return True - -def check_lfs_corruption(model_path): - corrupted_files = [] - for file_path in model_path.glob("*"): - if file_path.is_file() and file_path.stat().st_size < 200: - try: - with open(file_path, 'rb') as f: - header = f.read(100) - if b'version https://git-lfs.github.com/spec/v1' in header: - corrupted_files.append(file_path.name) - except: - pass - if corrupted_files: - return False, f"LFS pointers instead of files: {', '.join(corrupted_files)}" - return True, "No LFS corruption detected" - -def check_model_health(model_spec): - model_path, model_name, commit_hash = resolve_single_model(model_spec) - if not model_path: - # resolve_single_model already printed the appropriate error message - return False - - print(f"Checking model: {model_name}") - if commit_hash: - print(f"Hash: {commit_hash}") - - # Use the robust health check - if is_model_healthy(model_spec): - print("\n[OK] Model is healthy and usable!") - return True - else: - # Detailed diagnosis for WHY it's unhealthy - print("\n[ERROR] Model is corrupted. Detailed diagnosis:") - - # Check config.json - config_path = model_path / "config.json" - if not config_path.exists(): - print(" - config.json missing") - else: - try: - with open(config_path) as f: - config_data = json.load(f) - if not isinstance(config_data, dict) or len(config_data) == 0: - print(" - config.json is empty or invalid") - else: - print(" - config.json found and valid") - except (OSError, json.JSONDecodeError): - print(" - config.json exists but contains invalid JSON") - - # Check weight files (including gguf support like is_model_healthy) - weight_files = list(model_path.glob("*.safetensors")) + list(model_path.glob("*.bin")) + list(model_path.glob("*.gguf")) - if not weight_files: - weight_files = list(model_path.glob("**/*.safetensors")) + list(model_path.glob("**/*.bin")) + list(model_path.glob("**/*.gguf")) - - if weight_files: - total_size = sum(f.stat().st_size for f in weight_files) - size_mb = total_size / (1024 * 1024) - print(f" - Model weights found ({len(weight_files)} files, {size_mb:.1f}MB)") - elif (model_path / "model.safetensors.index.json").exists(): - # Check multi-file model - try: - with open(model_path / "model.safetensors.index.json") as f: - index = json.load(f) - if 'weight_map' in index: - referenced_files = set(index['weight_map'].values()) - existing_files = [f for f in referenced_files if (model_path / f).exists()] - if existing_files: - total_size = sum((model_path / f).stat().st_size for f in existing_files) - size_mb = total_size / (1024 * 1024) - print(f" - Multi-file weights ({len(existing_files)}/{len(referenced_files)} files, {size_mb:.1f}MB)") - if len(existing_files) < len(referenced_files): - print(" - Incomplete multi-file model") - else: - print(" - Multi-file model index found but no weight files exist") - else: - print(" - Multi-file model index is invalid") - except Exception as e: - print(f" - Multi-file model index error: {e}") - else: - print(" - No model weights found (.safetensors, .bin, .gguf)") - - # Check LFS corruption - lfs_ok, lfs_msg = check_lfs_corruption(model_path) - if not lfs_ok: - print(f" - {lfs_msg}") - else: - print(f" - {lfs_msg}") - - # Show framework - framework = detect_framework(model_path.parent.parent, model_name) - print(f" - Framework: {framework}") - - # Offer deletion for corrupted models - confirm = input("\nModel appears corrupted. Delete? [y/N] ") - if confirm.lower() == "y": - import errno - import shutil - try: - if commit_hash: - # Delete specific hash/snapshot - shutil.rmtree(model_path) - print(f"Hash {commit_hash} deleted.") - else: - # Delete entire model directory (go up from snapshots or use base_cache_dir) - if model_path.name.startswith("models--"): - # model_path is base_cache_dir (corrupted model case) - shutil.rmtree(model_path) - else: - # model_path is snapshot dir - model_base_dir = model_path.parent.parent - shutil.rmtree(model_base_dir) - print(f"Model {model_name} deleted.") - except PermissionError as e: - print(f"[ERROR] Permission denied: Cannot delete {e.filename}") - print(" Try running with appropriate permissions or manually delete the directory.") - except OSError as e: - if e.errno == errno.ENOTEMPTY: - print(f"[ERROR] Directory not empty: {e.filename}") - print(" Another process may be using this model.") - elif e.errno == errno.EACCES: - print(f"[ERROR] Access denied: {e.filename}") - else: - print(f"[ERROR] OS Error while deleting: {e}") - except Exception as e: - print(f"[ERROR] Unexpected error while deleting: {type(e).__name__}: {e}") - - return False - -def check_all_models_health(): - models = [d for d in MODEL_CACHE.iterdir() if d.name.startswith("models--")] - if not models: - print("No models found in HuggingFace cache.") - return - print(f"Checking {len(models)} models for integrity...\n") - healthy_models = [] - problematic_models = [] - for model_dir in sorted(models, key=lambda x: x.stat().st_mtime, reverse=True): - hf_name = cache_dir_to_hf(model_dir.name) - model_hash = get_model_hash(model_dir) - print(f"{hf_name} ({model_hash})") - if is_model_healthy(hf_name): - healthy_models.append((hf_name, model_hash)) - print(" [OK] Healthy\n") - else: - problematic_models.append((hf_name, model_hash)) - print(" [ERROR] Problematic\n") - print("=" * 50) - print("Summary:") - print(f"[OK] Healthy models: {len(healthy_models)}") - print(f"[ERROR] Problematic models: {len(problematic_models)}") - if problematic_models: - print("\n[WARNING] Problematic models:") - for name, hash_id in problematic_models: - print(f" - {name} ({hash_id})") - print("\nRepair tips:") - print(" python mlx_knife.cli pull # Re-download") - print(" python mlx_knife.cli rm # Delete") - print(" python mlx_knife.cli health # Show details") - return len(problematic_models) == 0 - -def list_models(show_all=False, framework_filter=None, show_health=False, single_model=None, verbose=False): - if single_model: - # Try exact match first - expanded_model = expand_model_name(single_model) - model_dir = MODEL_CACHE / hf_to_cache_dir(expanded_model) - - if model_dir.exists(): - models = [model_dir] - else: - # If exact match fails, do partial name matching - if not MODEL_CACHE.exists(): - print(f"No models found matching '{single_model}' - cache directory doesn't exist yet.") - print("Use 'mlxk pull ' to download models first.") - return - all_models = [d for d in MODEL_CACHE.iterdir() if d.name.startswith("models--")] - matching_models = [] - - for model_dir in all_models: - hf_name = cache_dir_to_hf(model_dir.name) - # Check if the pattern appears in the model name (case insensitive) - if single_model.lower() in hf_name.lower(): - matching_models.append(model_dir) - - if not matching_models: - print(f"No models found matching '{single_model}'!") - return - - models = matching_models - else: - if not MODEL_CACHE.exists(): - print("No models found - cache directory doesn't exist yet.") - print("Use 'mlxk pull ' to download models first.") - return - models = [d for d in MODEL_CACHE.iterdir() if d.name.startswith("models--")] - if not models: - print("No models found in HuggingFace cache.") - return - if show_health: - if show_all: - print(f"{'NAME':<40} {'ID':<10} {'SIZE':<10} {'MODIFIED':<15} {'FRAMEWORK':<10} {'HEALTH':<8}") - else: - print(f"{'NAME':<40} {'ID':<10} {'SIZE':<10} {'MODIFIED':<15} {'HEALTH':<8}") - else: - if show_all: - print(f"{'NAME':<40} {'ID':<10} {'SIZE':<10} {'MODIFIED':<15} {'FRAMEWORK':<10}") - else: - print(f"{'NAME':<40} {'ID':<10} {'SIZE':<10} {'MODIFIED':<15}") - for m in sorted(models, key=lambda x: x.stat().st_mtime, reverse=True): - hf_name = cache_dir_to_hf(m.name) - size = get_model_size(m) - modified = get_model_modified(m) - model_hash = get_model_hash(m) - framework = detect_framework(m, hf_name) - if framework_filter and framework.lower() != framework_filter: - continue - if not show_all and not framework_filter and framework != "MLX": - continue - # Handle display name based on verbose flag - display_name = hf_name - if hf_name.startswith("mlx-community/") and not verbose: - # For MLX models, hide prefix unless verbose is set - display_name = hf_name[len("mlx-community/"):] - health_status = "" - if show_health: - health_status = "[OK]" if is_model_healthy(hf_name) else "[ERR]" - if show_all: - print(f"{display_name:<40} {model_hash:<10} {size:<10} {modified:<15} {framework:<10} {health_status:<8}") - else: - print(f"{display_name:<40} {model_hash:<10} {size:<10} {modified:<15} {health_status:<8}") - else: - if show_all: - print(f"{display_name:<40} {model_hash:<10} {size:<10} {modified:<15} {framework:<10}") - else: - print(f"{display_name:<40} {model_hash:<10} {size:<10} {modified:<15}") - -def run_model(model_spec, prompt=None, interactive=False, temperature=0.7, - max_tokens=500, top_p=0.9, repetition_penalty=1.1, stream=True, - use_chat_template=True, verbose=False): - """Run an MLX model with enhanced features. - - Args: - model_spec: Model specification (name[@hash]) - prompt: Input prompt (if None and not interactive, enters interactive mode) - interactive: Force interactive mode - temperature: Sampling temperature - max_tokens: Maximum tokens to generate - top_p: Top-p sampling parameter - repetition_penalty: Penalty for repeated tokens - stream: Whether to stream output - """ - model_path, model_name, commit_hash = resolve_single_model(model_spec) - if not model_path: - print(f"Use: mlxk pull {model_spec}") - sys.exit(1) - - framework = detect_framework(model_path.parent.parent, model_name) - if framework != "MLX": - print(f"Model {model_name} is not MLX-compatible (Framework: {framework})!") - print("Use MLX-Community models: https://huggingface.co/mlx-community") - sys.exit(1) - - # Try to use the enhanced runner - try: - from .mlx_runner import run_model_enhanced - - run_model_enhanced( - model_path=str(model_path), - prompt=prompt, - interactive=interactive, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - stream=stream, - use_chat_template=use_chat_template, - verbose=verbose, - ) - except ImportError: - # Fallback to subprocess if mlx_runner is not available - print("[WARNING] Enhanced runner not available, falling back to subprocess mode") - print(f"Running model: {model_name}") - if commit_hash: - print(f"Hash: {commit_hash}") - print(f"Cache path: {model_path}") - - if interactive or prompt is None: - print("Interactive mode not supported in fallback mode") - prompt = prompt or "Hello" - - print(f"Prompt: {prompt}\n") - os.system(f'python -m mlx_lm generate --model "{model_path}" --prompt "{prompt}"') - -def show_model(model_spec, show_files=False, show_config=False): - """Show detailed information about a specific model.""" - model_path, model_name, commit_hash = resolve_single_model(model_spec) - - if not model_path: - return False - - # Basic information - print(f"Model: {model_name}") - print(f"Path: {model_path}") - - if commit_hash: - print(f"Snapshot: {commit_hash}") - else: - # Show current snapshot hash - current_hash = model_path.name - print(f"Snapshot: {current_hash}") - - # Size - size = get_model_size(model_path) - print(f"Size: {size}") - - # Modified time - modified = get_model_modified(model_path) - print(f"Modified: {modified}") - - # Framework - framework = detect_framework(model_path.parent.parent, model_name) - print(f"Framework: {framework}") - - # Quantization and Precision info - config_path = model_path / "config.json" - quantization_info = None - precision_info = None - gguf_variants = [] - - if config_path.exists(): - try: - with open(config_path) as f: - config_data = json.load(f) - - # 1. Check for explicit quantization field (MLX style) - if "quantization" in config_data and isinstance(config_data["quantization"], dict): - quant = config_data["quantization"] - if "bits" in quant: - quantization_info = f"{quant['bits']}-bit" - precision_info = f"int{quant['bits']}" - if "group_size" in quant: - quantization_info += f" (group_size: {quant['group_size']})" - - # 2. Check torch_dtype (HuggingFace standard) - elif "torch_dtype" in config_data: - dtype = config_data["torch_dtype"] - precision_info = dtype - # Check if model name suggests quantization - name_lower = model_name.lower() - if "4bit" in name_lower or "-4b" in name_lower: - quantization_info = "4-bit (inferred from name)" - elif "8bit" in name_lower or "-8b" in name_lower: - quantization_info = "8-bit (inferred from name)" - else: - quantization_info = "No quantization detected" - - # 3. Special handling for GGUF files - gguf_files = sorted(list(model_path.glob("*.gguf"))) - if gguf_files and not quantization_info: - # Collect all GGUF variants - gguf_variants = [] - for f in gguf_files: - name = f.name - size_mb = f.stat().st_size / (1024 * 1024) - - # Parse quantization type from filename - name_lower = name.lower() - if "q2_k" in name_lower: - variant_info = f"Q2_K (2-bit, {size_mb:.0f} MB)" - elif "q3_k_s" in name_lower: - variant_info = f"Q3_K_S (3-bit small, {size_mb:.0f} MB)" - elif "q3_k_m" in name_lower: - variant_info = f"Q3_K_M (3-bit medium, {size_mb:.0f} MB)" - elif "q3_k_l" in name_lower: - variant_info = f"Q3_K_L (3-bit large, {size_mb:.0f} MB)" - elif "q3_k" in name_lower: - variant_info = f"Q3_K (3-bit, {size_mb:.0f} MB)" - elif "q4_0" in name_lower: - variant_info = f"Q4_0 (4-bit, {size_mb:.0f} MB)" - elif "q4_k_s" in name_lower: - variant_info = f"Q4_K_S (4-bit small, {size_mb:.0f} MB)" - elif "q4_k_m" in name_lower: - variant_info = f"Q4_K_M (4-bit medium, {size_mb:.0f} MB)" - elif "q4_k" in name_lower: - variant_info = f"Q4_K (4-bit, {size_mb:.0f} MB)" - elif "q5_0" in name_lower: - variant_info = f"Q5_0 (5-bit, {size_mb:.0f} MB)" - elif "q5_k_s" in name_lower: - variant_info = f"Q5_K_S (5-bit small, {size_mb:.0f} MB)" - elif "q5_k_m" in name_lower: - variant_info = f"Q5_K_M (5-bit medium, {size_mb:.0f} MB)" - elif "q5_k" in name_lower: - variant_info = f"Q5_K (5-bit, {size_mb:.0f} MB)" - elif "q6_k" in name_lower: - variant_info = f"Q6_K (6-bit, {size_mb:.0f} MB)" - elif "q8_0" in name_lower: - variant_info = f"Q8_0 (8-bit, {size_mb:.0f} MB)" - else: - variant_info = f"{name} ({size_mb:.0f} MB)" - - gguf_variants.append(variant_info) - - if len(gguf_variants) > 1: - quantization_info = "Multiple GGUF variants available" - precision_info = "gguf (see variants below)" - elif len(gguf_variants) == 1: - quantization_info = gguf_variants[0].split(' (')[0] - precision_info = "gguf" - else: - quantization_info = "GGUF format (quantization unknown)" - precision_info = "gguf" - - except (OSError, json.JSONDecodeError, KeyError): - pass - - # Display quantization and precision info - if quantization_info: - print(f"Quantization: {quantization_info}") - else: - print("Quantization: Unknown (no info in config)") - - if precision_info: - print(f"Precision: {precision_info}") - else: - print("Precision: Unknown") - - # Display GGUF variants if available - if gguf_variants and len(gguf_variants) > 1: - print("\nAvailable GGUF variants:") - for variant in gguf_variants: - print(f" - {variant}") - - # Health status - health_ok = is_model_healthy(model_name) - if health_ok: - print("Health: [OK]") - else: - print("Health: [ERROR] CORRUPTED") - # Check specific issues - issues = [] - if not (model_path / "config.json").exists(): - issues.append("config.json missing") - - weight_files = list(model_path.glob("*.safetensors")) + list(model_path.glob("*.bin")) + list(model_path.glob("*.gguf")) - if not weight_files: - weight_files = list(model_path.glob("**/*.safetensors")) + list(model_path.glob("**/*.bin")) + list(model_path.glob("**/*.gguf")) - if not weight_files: - index_file = model_path / "model.safetensors.index.json" - if not index_file.exists(): - issues.append("No model weights found") - - lfs_ok, lfs_msg = check_lfs_corruption(model_path) - if not lfs_ok: - issues.append(lfs_msg) - - if issues: - print(" Issues:") - for issue in issues: - print(f" - {issue}") - - # Show files if requested - if show_files: - print("\nFiles:") - files = [] - for file in sorted(model_path.rglob("*")): - if file.is_file(): - relative_path = file.relative_to(model_path) - file_size = file.stat().st_size - if file_size >= 1_000_000_000: - size_str = f"{file_size / 1_000_000_000:.2f} GB" - elif file_size >= 1_000_000: - size_str = f"{file_size / 1_000_000:.2f} MB" - elif file_size >= 1_000: - size_str = f"{file_size / 1_000:.2f} KB" - else: - size_str = f"{file_size} B" - files.append((str(relative_path), size_str)) - - # Print files in a nice table format - if files: - max_name_len = max(len(f[0]) for f in files) - for file_path, file_size in files: - print(f" {file_path:<{max_name_len}} {file_size:>10}") - else: - print(" No files found") - - # Show config if requested - if show_config: - config_path = model_path / "config.json" - if config_path.exists(): - print("\nConfig:") - try: - with open(config_path) as f: - config_data = json.load(f) - print(json.dumps(config_data, indent=2)) - except Exception as e: - print(f" Error reading config: {e}") - else: - print("\nConfig: Not found") - - return True - -def rm_model(model_spec, force=False): - original_spec = model_spec - - # First try to resolve using fuzzy matching - resolved_path, resolved_name, resolved_hash = resolve_single_model(model_spec) - - if not resolved_path: - # resolve_single_model already printed the error message for most cases - # But ensure we always provide feedback to the user - print(f"Model '{original_spec}' not found or corrupted.") - return - - # Use the resolved model name for deletion - model_name = resolved_name - commit_hash = resolved_hash - - - # Confirm on auto-expansion (if the resolved name is different from input) - base_spec = original_spec.split("@")[0] if "@" in original_spec else original_spec - if base_spec != model_name and "/" not in base_spec: - confirm = input(f"Delete '{model_name}' (matched from '{base_spec}')? [Y/n] ") - if confirm.lower() == "n": - print("Delete aborted.") - return - - base_cache_dir = MODEL_CACHE / hf_to_cache_dir(model_name) - # This should exist since resolve_single_model succeeded, but double-check - if not base_cache_dir.exists(): - print(f"[ERROR] Model directory disappeared: {model_name}") - return - # Specific hash to delete? - if commit_hash: - hash_dir = base_cache_dir / "snapshots" / commit_hash - if not hash_dir.exists(): - print(f"Hash {commit_hash} for model {model_name} not found!") - print("\nAvailable hashes:") - snapshots_dir = base_cache_dir / "snapshots" - if snapshots_dir.exists(): - for snapshot in sorted(snapshots_dir.iterdir()): - if snapshot.is_dir(): - print(f" {snapshot.name[:8]}") - return - if force: - confirm_delete = True - else: - confirm = input(f"Delete hash {commit_hash} of model {model_name}? [y/N] ") - confirm_delete = confirm.lower() == "y" - - if confirm_delete: - # Issue #23 Fix: Delete entire model directory, not just the snapshot - # This prevents the double-execution problem where refs/ remain intact - shutil.rmtree(base_cache_dir) - print(f"{model_name}@{commit_hash} deleted") - - # Clean up associated lock files - try: - _cleanup_model_locks(model_name, force) - except Exception as e: - print(f"Warning: Could not clean up cache files: {e}") - else: - print("Aborted.") - else: - # Delete entire model - if force: - confirm_delete = True - else: - confirm = input(f"Delete entire model {model_name} ({base_cache_dir})? [y/N] ") - confirm_delete = confirm.lower() == "y" - - if confirm_delete: - shutil.rmtree(base_cache_dir) - print(f"Model {model_name} completely deleted.") - - # Clean up associated lock files - try: - _cleanup_model_locks(model_name, force) - except Exception as e: - print(f"Warning: Could not clean up cache files: {e}") - else: - print("Aborted.") - - -def _cleanup_model_locks(model_name, force=False): - """Clean up HuggingFace lock files for a deleted model. - - Args: - model_name: The model name (e.g. 'microsoft/DialoGPT-small') - force: If True, delete without asking. If False, prompt user. - """ - locks_dir = MODEL_CACHE / ".locks" / hf_to_cache_dir(model_name) - - if not locks_dir.exists(): - return # No locks to clean up - - # Count lock files - try: - lock_files = list(locks_dir.iterdir()) - if not lock_files: - return # Empty directory - - if force: - # Delete without asking - shutil.rmtree(locks_dir) - print(f"Cleaned up cache files ({len(lock_files)} files).") - else: - # Ask user - confirm = input("Clean up cache files? [Y/n] ") - if confirm.lower() != "n": - shutil.rmtree(locks_dir) - print(f"Cache files cleaned up ({len(lock_files)} files).") - else: - print("Cache files left intact.") - - except Exception as e: - print(f"Warning: Could not clean up cache files: {e}") diff --git a/mlx_knife/cli.py b/mlx_knife/cli.py deleted file mode 100644 index 594a14f..0000000 --- a/mlx_knife/cli.py +++ /dev/null @@ -1,133 +0,0 @@ -# mlx_knife/cli.py - -import argparse -import sys - -from . import __version__ -from .cache_utils import ( - check_all_models_health, - check_model_health, - list_models, - rm_model, - run_model, - show_model, -) -from .hf_download import pull_model -from .server import run_server - - -def main(): - parser = argparse.ArgumentParser( - description="MLX Knife CLI (HuggingFace-style cache management for MLX models)" - ) - parser.add_argument('--version', action='version', version=f'MLX Knife {__version__}') - subparsers = parser.add_subparsers(dest="cmd") - - # list - list_p = subparsers.add_parser("list", help="List available models in cache") - list_p.add_argument("model", nargs="?", help="Specific model to list (optional)") - list_p.add_argument("--all", action="store_true", help="Show all models (not just MLX)") - list_p.add_argument("--framework", choices=["mlx", "pytorch", "tokenizer"], help="Filter by framework") - list_p.add_argument("--health", action="store_true", help="Show health status") - list_p.add_argument("--verbose", action="store_true", help="Show detailed information (requires model argument)") - - # pull - pull_p = subparsers.add_parser("pull", help="Download a model from HuggingFace") - pull_p.add_argument("model_spec", help="Model[@hash] (e.g. mlx-community/Qwen2.5-0.5B-Instruct-4bit@a5339a41)") - - # run - run_p = subparsers.add_parser("run", help="Run a model with prompt") - run_p.add_argument("model_spec", help="Model[@hash] (e.g. mlx-community/Qwen2.5-0.5B-Instruct-4bit@a5339a41)") - run_p.add_argument("prompt", nargs="?", default=None, help="Prompt text (if not provided, enters interactive mode)") - run_p.add_argument("--interactive", "-i", action="store_true", help="Force interactive dialog mode") - run_p.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature (default: 0.7)") - run_p.add_argument("--max-tokens", type=int, default=None, help="Maximum tokens to generate (default: model context length)") - run_p.add_argument("--top-p", type=float, default=0.9, help="Top-p sampling parameter (default: 0.9)") - run_p.add_argument("--repetition-penalty", type=float, default=1.1, help="Penalty for repeated tokens (default: 1.1)") - run_p.add_argument("--no-stream", action="store_true", help="Disable streaming output") - run_p.add_argument("--no-chat-template", action="store_true", help="Disable chat template formatting (use raw prompt)") - run_p.add_argument("--verbose", "-v", action="store_true", help="Show detailed output (model loading, memory usage, token stats)") - - # rm - rm_p = subparsers.add_parser("rm", help="Delete a model from cache") - rm_p.add_argument("model_spec", help="Model[@hash] (e.g. mlx-community/Qwen2.5-0.5B-Instruct-4bit@a5339a41)") - rm_p.add_argument("--force", action="store_true", help="Skip confirmation and clean up cache files automatically") - - # health - health_p = subparsers.add_parser("health", help="Check model integrity") - health_p.add_argument("model_spec", nargs="?", help="Model[@hash] (optional)") - health_p.add_argument("--all", action="store_true", help="Check all models in cache") - - # show - show_p = subparsers.add_parser("show", help="Show detailed information about a specific model") - show_p.add_argument("model_spec", help="Model[@hash] (e.g. mlx-community/Qwen2.5-0.5B-Instruct-4bit@a5339a41)") - show_p.add_argument("--files", action="store_true", help="List all files and sizes under the model path") - show_p.add_argument("--config", action="store_true", help="Print pretty-formatted config.json") - - # server - server_p = subparsers.add_parser("server", help="Start OpenAI-compatible API server") - server_p.add_argument("--host", default="127.0.0.1", help="Server host (default: 127.0.0.1)") - server_p.add_argument("--port", type=int, default=8000, help="Server port (default: 8000)") - server_p.add_argument("--max-tokens", type=int, default=None, help="Default max tokens for completions (default: model-aware dynamic limits)") - server_p.add_argument("--reload", action="store_true", help="Enable auto-reload for development") - server_p.add_argument("--log-level", default="info", choices=["debug", "info", "warning", "error"], help="Log level (default: info)") - - args = parser.parse_args() - - if args.cmd == "list": - if args.model: - if args.verbose and not args.all and not args.framework and not args.health: - # Show detailed info for a specific model (same as show command) - show_model(args.model) - else: - # Show just the single model row - list_models(show_all=args.all, framework_filter=args.framework, show_health=args.health, single_model=args.model, verbose=args.verbose) - else: - # Normal list behavior - verbose works with MLX models too - list_models(show_all=args.all, framework_filter=args.framework, show_health=args.health, verbose=args.verbose) - elif args.cmd == "pull": - pull_model(args.model_spec) - elif args.cmd == "run": - run_model( - args.model_spec, - prompt=args.prompt, - interactive=args.interactive, - temperature=args.temperature, - max_tokens=args.max_tokens, - top_p=args.top_p, - repetition_penalty=args.repetition_penalty, - stream=not args.no_stream, - use_chat_template=not args.no_chat_template, - verbose=args.verbose - ) - elif args.cmd == "rm": - rm_model(args.model_spec, force=args.force) - elif args.cmd == "health": - if args.model_spec: - check_model_health(args.model_spec) - else: - # Default to checking all models if no specific model is provided - check_all_models_health() - elif args.cmd == "show": - show_model(args.model_spec, show_files=args.files, show_config=args.config) - elif args.cmd == "server": - # Validate server arguments - if args.max_tokens is not None and args.max_tokens <= 0: - print(f"Error: --max-tokens must be positive, got: {args.max_tokens}") - sys.exit(1) - if args.port <= 0 or args.port > 65535: - print(f"Error: --port must be between 1-65535, got: {args.port}") - sys.exit(1) - - run_server( - host=args.host, - port=args.port, - max_tokens=args.max_tokens, - reload=args.reload, - log_level=args.log_level - ) - else: - parser.print_help() - -if __name__ == "__main__": - main() diff --git a/mlx_knife/hf_download.py b/mlx_knife/hf_download.py deleted file mode 100644 index c0aa217..0000000 --- a/mlx_knife/hf_download.py +++ /dev/null @@ -1,141 +0,0 @@ -import json -import os -import subprocess -import sys -import tempfile - -try: - from .cache_utils import ( - MODEL_CACHE, - hf_to_cache_dir, - is_model_healthy, - parse_model_spec, - ) -except ImportError: - from pathlib import Path - def parse_model_spec(x): return (x, None) - def hf_to_cache_dir(x): return x - if "HF_HOME" in os.environ: - MODEL_CACHE = Path(os.environ["HF_HOME"]) / "hub" - else: - MODEL_CACHE = Path(os.path.expanduser("~/.cache/huggingface/hub")) - def is_model_healthy(x): return False - -def describe_http_exception(exc): - if hasattr(exc, "response") and exc.response is not None: - status = getattr(exc.response, "status_code", None) - url = getattr(exc.response, "url", None) - if status == 401: - return f"[ERROR] Unauthorized (401): Check your HuggingFace token or login.\nURL: {url}" - elif status == 403: - return f"[ERROR] Forbidden (403): Access denied.\nURL: {url}" - elif status == 404: - return f"[ERROR] Not Found (404): Resource does not exist.\nURL: {url}" - elif status >= 500: - return f"[ERROR] Server Error ({status}): Problem on HuggingFace's side.\nURL: {url}\nTry again later." - else: - return f"[ERROR] HTTP Error {status}: {exc}\nURL: {url}" - return f"[ERROR] HTTP Error: {exc}" - -def configure_download_environment(): - os.environ['HF_HUB_DOWNLOAD_THREADS'] = '1' - os.environ['HF_HUB_DOWNLOAD_CHUNK_SIZE'] = '524288' # 512KB chunks for household-friendly downloads - os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = 'false' - -def pull_model(model_spec): - original_spec = model_spec - model_name, commit_hash = parse_model_spec(model_spec) - - # Validate HuggingFace Hub repository name length limit (Issue #6) - if len(model_name) > 96: - print(f"[ERROR] Repository name exceeds HuggingFace Hub limit: {len(model_name)}/96 characters") - print("Repository names longer than 96 characters cannot exist on HuggingFace Hub.") - print(f"Invalid name: '{model_name}'") - return False - - if "/" not in original_spec.split("@")[0] and "/" in model_name: - confirm = input(f"Download '{model_name}'? [Y/n] ") - if confirm.lower() == "n": - print("Download cancelled.") - return - - base_cache_dir = MODEL_CACHE / hf_to_cache_dir(model_name) - if commit_hash: - hash_dir = base_cache_dir / "snapshots" / commit_hash - if hash_dir.exists() and is_model_healthy(f"{model_name}@{commit_hash}"): - print("Model already exists") - return - else: - if base_cache_dir.exists() and is_model_healthy(model_name): - print("Model already exists") - return - - print(f"Downloading {model_name}...") - - # Build kwargs dict for the worker - kwargs_dict = { - "repo_id": model_name, - "local_dir_use_symlinks": False, - "max_workers": 1 - } - if commit_hash: - kwargs_dict["revision"] = commit_hash - if "mlx-community" in model_name: - kwargs_dict["allow_patterns"] = [ - "*.json", "*.txt", "*.safetensors", "*.md", "*.gitattributes", "LICENSE" - ] - if "mlx-community" not in model_name: - confirm = input(f"[WARNING] {model_name} is not an MLX model (may be >1GB). Continue? [y/N] ") - if confirm.lower() != "y": - print("Download cancelled.") - return - - kwargs_str = json.dumps(kwargs_dict, indent=2) - with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: - f.write(kwargs_str) - kwargs_file = f.name - - # Call the worker as subprocess with nice priority - worker_path = os.path.join(os.path.dirname(__file__), "throttled_download_worker.py") - try: - result = subprocess.run( - ['nice', '-n', '19', sys.executable, worker_path, kwargs_file], - check=False - ) - if result.returncode == 0: - print("Download completed successfully.") - elif result.returncode in (10, 11, 12, 13, 14, 15): - # Already handled in worker, do NOT retry fallback - print("[WARNING] Fatal error encountered in throttled download, not attempting fallback.") - return - else: - print("[WARNING] Throttled download failed or was interrupted.") - print("Attempting fallback download with standard throttling...") - try: - import requests - from huggingface_hub import snapshot_download - configure_download_environment() - snapshot_download(**kwargs_dict) - print("Download completed successfully.") - except requests.exceptions.HTTPError as e: - print(describe_http_exception(e)) - return - except requests.exceptions.ConnectionError: - print("[ERROR] Network connection error. Please check your internet connection and try again.") - return - except requests.exceptions.Timeout: - print("[ERROR] Download timed out. Please try again.") - return - except KeyboardInterrupt: - print("\n[WARNING] Download cancelled by user.") - return - except Exception as e: - print(f"[ERROR] Unexpected error during fallback download: {type(e).__name__}: {e}") - return - except KeyboardInterrupt: - print("\n[WARNING] Download cancelled by user.") - return - except ImportError: - print("huggingface-hub is not installed. Please install it with: pip install huggingface-hub") - except Exception as e: - print(f"[ERROR] Unexpected error: {type(e).__name__}: {e}") diff --git a/mlx_knife/mlx_runner.py b/mlx_knife/mlx_runner.py deleted file mode 100644 index e7201e6..0000000 --- a/mlx_knife/mlx_runner.py +++ /dev/null @@ -1,811 +0,0 @@ -# mlx_knife/mlx_runner.py -""" -Enhanced MLX model runner with direct API integration. -Provides ollama-like run experience with streaming and interactive chat. -""" - -import json -import os -import time -from collections.abc import Iterator -from pathlib import Path -from typing import Dict, Optional - -import mlx.core as mx -from mlx_lm import load -from mlx_lm.generate import generate_step -from mlx_lm.sample_utils import make_repetition_penalty, make_sampler - - -def get_model_context_length(model_path: str) -> int: - """Extract max_position_embeddings from model config. - - Args: - model_path: Path to the MLX model directory - - Returns: - Maximum context length for the model (defaults to 4096 if not found) - """ - config_path = os.path.join(model_path, "config.json") - - try: - with open(config_path) as f: - config = json.load(f) - - # Try various common config keys for context length - context_keys = [ - "max_position_embeddings", - "n_positions", - "context_length", - "max_sequence_length", - "seq_len" - ] - - for key in context_keys: - if key in config: - return config[key] - - # If no context length found, return reasonable default - return 4096 - - except (FileNotFoundError, json.JSONDecodeError, KeyError): - # Return default if config can't be read - return 4096 - - -class MLXRunner: - """Direct MLX model runner with streaming and interactive capabilities.""" - - def __init__(self, model_path: str, adapter_path: Optional[str] = None, verbose: bool = False): - """Initialize the runner with a model. - - Args: - model_path: Path to the MLX model directory - adapter_path: Optional path to LoRA adapter - verbose: Show detailed output - """ - self.model_path = Path(model_path) - self.adapter_path = adapter_path - self.model = None - self.tokenizer = None - self._memory_baseline = None - self._stop_tokens = None # Will be populated from tokenizer - self._chat_stop_tokens = None # Chat-specific stop tokens - self._context_length = None # Will be populated from model config - self.verbose = verbose - self._model_loaded = False - self._context_entered = False # Prevent nested context usage - - def __enter__(self): - """Context manager entry - loads the model.""" - if self._context_entered: - raise RuntimeError("MLXRunner context manager cannot be entered multiple times") - - self._context_entered = True - try: - self.load_model() - return self - except Exception: - # If load_model fails, ensure cleanup happens - self._context_entered = False - self.cleanup() - raise - - def __exit__(self, exc_type, exc_val, exc_tb): - """Context manager exit - cleans up the model.""" - self._context_entered = False - self.cleanup() - return False # Don't suppress exceptions - - def load_model(self): - """Load the MLX model and tokenizer.""" - if self._model_loaded: - if self.verbose: - print("Model already loaded, skipping...") - return - - if self.verbose: - print(f"Loading model from {self.model_path}...") - start_time = time.time() - - # Capture baseline memory before loading - try: - mx.clear_cache() - except Exception: - pass # Continue even if cache clear fails - self._memory_baseline = mx.get_active_memory() / 1024**3 - - try: - # Load model and tokenizer - self.model, self.tokenizer = load( - str(self.model_path), - adapter_path=self.adapter_path - ) - - load_time = time.time() - start_time - current_memory = mx.get_active_memory() / 1024**3 - model_memory = current_memory - self._memory_baseline - - if self.verbose: - print(f"Model loaded in {load_time:.1f}s") - print(f"Memory: {model_memory:.1f}GB model, {current_memory:.1f}GB total") - - # Extract stop tokens from tokenizer - self._extract_stop_tokens() - - # Extract context length from model config - self._context_length = get_model_context_length(str(self.model_path)) - - if self.verbose: - print(f"Model context length: {self._context_length} tokens") - - self._model_loaded = True - - except Exception as e: - # Ensure partial state is cleaned up on failure - self.model = None - self.tokenizer = None - self._stop_tokens = None - self._model_loaded = False - # Clear any memory that might have been allocated - mx.clear_cache() - raise RuntimeError(f"Failed to load model from {self.model_path}: {e}") from e - - def _extract_stop_tokens(self): - """Extract stop tokens from the tokenizer dynamically.""" - self._stop_tokens = set() - - # Primary source: eos_token - eos_token = getattr(self.tokenizer, 'eos_token', None) - if eos_token: - self._stop_tokens.add(eos_token) - - # Also check pad_token if it's different from eos_token - pad_token = getattr(self.tokenizer, 'pad_token', None) - if pad_token and pad_token != eos_token: - self._stop_tokens.add(pad_token) - - # Check additional_special_tokens - if hasattr(self.tokenizer, 'additional_special_tokens'): - for token in self.tokenizer.additional_special_tokens: - if token and isinstance(token, str): - # Only add tokens that look like stop/end tokens - if any(keyword in token.lower() for keyword in ['end', 'stop', 'eot']): - self._stop_tokens.add(token) - - # Add common stop tokens that might not be in special tokens - # but are frequently used across models - common_stop_tokens = {'', '<|endoftext|>', '<|im_end|>'} - - # Add chat-specific stop tokens to prevent model self-conversations - # Based on our _format_conversation() format: "Human:" and "Assistant:" - # Also include "You:" as models might use UI-visible format - # Include single-letter variations (H:, A:, Y:) that some models use - chat_stop_tokens = { - '\nHuman:', '\nAssistant:', '\nYou:', - '\n\nHuman:', '\n\nAssistant:', '\n\nYou:', - '\nH:', '\nA:', '\nY:', # Single-letter variations - '\n\nH:', '\n\nA:', '\n\nY:' - } - - # Add common stop tokens only if they decode to themselves (i.e., they're single tokens) - for token in common_stop_tokens: - try: - # Try to encode and decode to verify it's a real single token - ids = self.tokenizer.encode(token, add_special_tokens=False) - if ids and len(ids) == 1: # Single token ID means it's a special token - decoded = self.tokenizer.decode(ids) - if decoded == token: - self._stop_tokens.add(token) - except: - pass - - # Store chat stop tokens separately - only used in interactive chat mode - # This prevents stopping mid-story when user asks for dialogues - self._chat_stop_tokens = list(chat_stop_tokens) - - # Remove any None values - self._stop_tokens.discard(None) - - # Convert to list for easier use - self._stop_tokens = list(self._stop_tokens) - - if self._stop_tokens and self.verbose: - print(f"Stop tokens: {self._stop_tokens}") - - def cleanup(self): - """Clean up model resources and clear GPU memory. - - This method is safe to call multiple times and handles partial state cleanup. - """ - if self.verbose and self._model_loaded: - memory_before = mx.get_active_memory() / 1024**3 - print(f"Cleaning up model (memory before: {memory_before:.1f}GB)...") - - # Always clean up, even if model wasn't fully loaded - self.model = None - self.tokenizer = None - self._stop_tokens = None - self._chat_stop_tokens = None - self._context_length = None - self._model_loaded = False - - # Force garbage collection and clear MLX cache - import gc - gc.collect() - try: - mx.clear_cache() - except Exception: - pass # Continue cleanup even if cache clear fails - - if self.verbose: - memory_after = mx.get_active_memory() / 1024**3 - if 'memory_before' in locals(): - memory_freed = memory_before - memory_after - print(f"Cleanup complete (memory after: {memory_after:.1f}GB, freed: {memory_freed:.1f}GB)") - else: - print(f"Cleanup complete (memory after: {memory_after:.1f}GB)") - - def get_effective_max_tokens(self, requested_tokens: Optional[int], interactive: bool = False) -> int: - """Get effective max tokens based on model context and usage mode. - - Args: - requested_tokens: The requested max tokens (None if user didn't specify --max-tokens) - interactive: True if this is interactive mode (gets full context length) - - Returns: - Effective max tokens to use - """ - if not self._context_length: - # Fallback when context length is unknown - fallback = 4096 if interactive else 2048 - if self.verbose: - if requested_tokens is None: - print(f"[WARNING] Model context length unknown, using fallback: {fallback} tokens") - else: - print(f"[WARNING] Model context length unknown, using user specified: {requested_tokens} tokens") - return requested_tokens if requested_tokens is not None else fallback - - if interactive: - if requested_tokens is None: - # User didn't specify --max-tokens: use full model context - return self._context_length - else: - # User specified --max-tokens explicitly: respect their choice but cap at context - return min(requested_tokens, self._context_length) - else: - # Server/batch mode uses half context length for DoS protection - server_limit = self._context_length // 2 - return min(requested_tokens or server_limit, server_limit) - - def generate_streaming( - self, - prompt: str, - max_tokens: int = 500, - temperature: float = 0.7, - top_p: float = 0.9, - repetition_penalty: float = 1.1, - repetition_context_size: int = 20, - use_chat_template: bool = True, - use_chat_stop_tokens: bool = False, - interactive: bool = False, - ) -> Iterator[str]: - """Generate text with streaming output. - - Args: - prompt: Input prompt - max_tokens: Maximum tokens to generate - temperature: Sampling temperature - top_p: Top-p sampling parameter - repetition_penalty: Penalty for repeated tokens - repetition_context_size: Context size for repetition penalty - use_chat_template: Apply tokenizer's chat template if available - use_chat_stop_tokens: Include chat turn markers as stop tokens (for interactive mode) - interactive: True if this is interactive mode (affects token limits) - - Yields: - Generated tokens as they are produced - """ - if not self.model or not self.tokenizer: - raise RuntimeError("Model not loaded. Call load_model() first.") - - # Apply context-aware token limits - effective_max_tokens = self.get_effective_max_tokens(max_tokens, interactive) - - # Apply chat template if available and requested - if use_chat_template and hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template: - messages = [{"role": "user", "content": prompt}] - formatted_prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True - ) - else: - formatted_prompt = prompt - - # Tokenize the prompt - prompt_tokens = self.tokenizer.encode(formatted_prompt) - prompt_array = mx.array(prompt_tokens) - - # Track generation metrics - start_time = time.time() - tokens_generated = 0 - - # Create sampler with our parameters - sampler = make_sampler(temp=temperature, top_p=top_p) - - # Create repetition penalty processor if needed - logits_processors = [] - if repetition_penalty > 1.0: - logits_processors.append( - make_repetition_penalty(repetition_penalty, repetition_context_size) - ) - - # Generate tokens one by one for streaming - generator = generate_step( - prompt=prompt_array, - model=self.model, - max_tokens=effective_max_tokens, - sampler=sampler, - logits_processors=logits_processors if logits_processors else None, - ) - - # Collect tokens and yield text - generated_tokens = [] - previous_decoded = "" - accumulated_response = "" # Track full response for stop token detection - - # Keep a sliding window of recent tokens for context - context_window = 10 # Decode last N tokens for proper spacing - - for token, _ in generator: - # Token might be an array or an int - token_id = token.item() if hasattr(token, 'item') else token - generated_tokens.append(token_id) - - # Use a sliding window approach for efficiency - start_idx = max(0, len(generated_tokens) - context_window) - window_tokens = generated_tokens[start_idx:] - - # Decode the window - window_text = self.tokenizer.decode(window_tokens) - - # Figure out what's new - if start_idx == 0: - # We're still within the context window - if window_text.startswith(previous_decoded): - new_text = window_text[len(previous_decoded):] - else: - new_text = self.tokenizer.decode([token_id]) - previous_decoded = window_text - else: - # We're beyond the context window, just decode the last token with context - # This is approximate but should preserve spaces - new_text = self.tokenizer.decode(window_tokens) - if len(window_tokens) > 1: - prefix = self.tokenizer.decode(window_tokens[:-1]) - if new_text.startswith(prefix): - new_text = new_text[len(prefix):] - else: - new_text = self.tokenizer.decode([token_id]) - - if new_text: - # Update accumulated response for stop token checking - accumulated_response += new_text - - # Filter out stop tokens with priority: native first, then chat fallback - # Check native stop tokens FIRST in accumulated response (highest priority) - native_stop_tokens = self._stop_tokens if self._stop_tokens else [] - for stop_token in native_stop_tokens: - if stop_token in accumulated_response: - # Find the stop token position and yield everything before it - stop_pos = accumulated_response.find(stop_token) - # Calculate what text came before the stop token - text_before_stop = accumulated_response[:stop_pos] - # Calculate how much of that is new (not previously yielded) - previously_yielded_length = len(accumulated_response) - len(new_text) - if len(text_before_stop) > previously_yielded_length: - # Yield only the new part before stop token - new_part_before_stop = text_before_stop[previously_yielded_length:] - if new_part_before_stop: - yield new_part_before_stop - return # Stop generation without yielding stop token - - # Only check chat stop tokens if no native stop token found (fallback) - if use_chat_stop_tokens and self._chat_stop_tokens: - for stop_token in self._chat_stop_tokens: - if stop_token in accumulated_response: - # Find the stop token position and yield everything before it - stop_pos = accumulated_response.find(stop_token) - # Calculate what text came before the stop token - text_before_stop = accumulated_response[:stop_pos] - # Calculate how much of that is new (not previously yielded) - previously_yielded_length = len(accumulated_response) - len(new_text) - if len(text_before_stop) > previously_yielded_length: - # Yield only the new part before stop token - new_part_before_stop = text_before_stop[previously_yielded_length:] - if new_part_before_stop: - yield new_part_before_stop - return # Stop generation without yielding stop token - - # No stop token found, yield the new text - yield new_text - tokens_generated += 1 - - # Check for EOS token - don't yield it - if token_id == self.tokenizer.eos_token_id: - break - - # Print generation statistics if verbose - if self.verbose: - generation_time = time.time() - start_time - tokens_per_second = tokens_generated / generation_time if generation_time > 0 else 0 - print(f"\n\nGenerated {tokens_generated} tokens in {generation_time:.1f}s ({tokens_per_second:.1f} tokens/s)") - - def generate_batch( - self, - prompt: str, - max_tokens: int = 500, - temperature: float = 0.7, - top_p: float = 0.9, - repetition_penalty: float = 1.1, - repetition_context_size: int = 20, - use_chat_template: bool = True, - interactive: bool = False, - ) -> str: - """Generate text in batch mode (non-streaming). - - Args: - prompt: Input prompt - max_tokens: Maximum tokens to generate - temperature: Sampling temperature - top_p: Top-p sampling parameter - repetition_penalty: Penalty for repeated tokens - repetition_context_size: Context size for repetition penalty - use_chat_template: Apply tokenizer's chat template if available - interactive: True if this is interactive mode (affects token limits) - - Returns: - Generated text - """ - if not self.model or not self.tokenizer: - raise RuntimeError("Model not loaded. Call load_model() first.") - - # Apply context-aware token limits - effective_max_tokens = self.get_effective_max_tokens(max_tokens, interactive) - - # Apply chat template if available and requested - if use_chat_template and hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template: - messages = [{"role": "user", "content": prompt}] - formatted_prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True - ) - else: - formatted_prompt = prompt - - start_time = time.time() - - # Tokenize the prompt - prompt_tokens = self.tokenizer.encode(formatted_prompt) - prompt_array = mx.array(prompt_tokens) - - # Create sampler with our parameters - sampler = make_sampler(temp=temperature, top_p=top_p) - - # Create repetition penalty processor if needed - logits_processors = [] - if repetition_penalty > 1.0: - logits_processors.append( - make_repetition_penalty(repetition_penalty, repetition_context_size) - ) - - # Generate all tokens at once - generated_tokens = [] - all_tokens = list(prompt_tokens) # Keep prompt for proper decoding - - generator = generate_step( - prompt=prompt_array, - model=self.model, - max_tokens=effective_max_tokens, - sampler=sampler, - logits_processors=logits_processors if logits_processors else None, - ) - - for token, _ in generator: - # Token might be an array or an int - token_id = token.item() if hasattr(token, 'item') else token - generated_tokens.append(token_id) - all_tokens.append(token_id) - - # Check for EOS token - don't yield it - if token_id == self.tokenizer.eos_token_id: - break - - # Decode all tokens together for proper spacing - full_response = self.tokenizer.decode(all_tokens) - - # Remove the prompt part - if full_response.startswith(formatted_prompt): - response = full_response[len(formatted_prompt):] - else: - # Fallback: just decode generated tokens - response = self.tokenizer.decode(generated_tokens) - - # Apply end-token filtering (same logic as streaming mode for Issue #20) - response = self._filter_end_tokens_from_response(response, use_chat_stop_tokens=False) - - generation_time = time.time() - start_time - - # Count tokens for statistics - if self.verbose: - tokens_generated = len(generated_tokens) - tokens_per_second = tokens_generated / generation_time if generation_time > 0 else 0 - print(f"\nGenerated {tokens_generated} tokens in {generation_time:.1f}s ({tokens_per_second:.1f} tokens/s)") - - return response - - def interactive_chat( - self, - system_prompt: Optional[str] = None, - max_tokens: int = 500, - temperature: float = 0.7, - top_p: float = 0.9, - repetition_penalty: float = 1.1, - ): - """Run an interactive chat session. - - Args: - system_prompt: Optional system prompt to prepend - max_tokens: Maximum tokens per response - temperature: Sampling temperature - top_p: Top-p sampling parameter - repetition_penalty: Penalty for repeated tokens - """ - print("Starting interactive chat. Type 'exit' or 'quit' to end.\n") - - conversation_history = [] - if system_prompt: - conversation_history.append({"role": "system", "content": system_prompt}) - - while True: - try: - # Get user input - user_input = input("You: ").strip() - - if user_input.lower() in ['exit', 'quit', 'q']: - print("\nGoodbye!") - break - - if not user_input: - continue - - # Add user message to history - conversation_history.append({"role": "user", "content": user_input}) - - # Format conversation for the model - # This is a simple format - models may need specific chat templates - prompt = self._format_conversation(conversation_history) - - # Generate response with streaming - print("\nAssistant: ", end="", flush=True) - - response_tokens = [] - for token in self.generate_streaming( - prompt=prompt, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - use_chat_stop_tokens=True, # Enable chat stop tokens in interactive mode - interactive=True, # Enable full context length for interactive mode - ): - print(token, end="", flush=True) - response_tokens.append(token) - - # Add assistant response to history - assistant_response = "".join(response_tokens).strip() - conversation_history.append({"role": "assistant", "content": assistant_response}) - - print() # New line after response - - except KeyboardInterrupt: - print("\n\nChat interrupted. Goodbye!") - break - except Exception as e: - print(f"\n[ERROR] {e}") - continue - - def _format_conversation(self, messages: list) -> str: - """Format conversation history into a prompt. - - This is a simple format. Different models may need different templates. - """ - formatted = [] - - for message in messages: - role = message["role"] - content = message["content"] - - if role == "system": - formatted.append(f"System: {content}") - elif role == "user": - formatted.append(f"Human: {content}") - elif role == "assistant": - formatted.append(f"Assistant: {content}") - - # Add prompt for next assistant response - formatted.append("Assistant:") - - return "\n\n".join(formatted) - - def get_memory_usage(self) -> Dict[str, float]: - """Get current memory usage statistics. - - Returns: - Dictionary with memory statistics in GB - """ - try: - current_memory = mx.get_active_memory() / 1024**3 - peak_memory = mx.get_peak_memory() / 1024**3 - except Exception: - # Return zeros if memory stats unavailable - current_memory = 0.0 - peak_memory = 0.0 - - return { - "current_gb": current_memory, - "peak_gb": peak_memory, - "model_gb": current_memory - self._memory_baseline if self._memory_baseline else 0, - } - - def _filter_end_tokens_from_response(self, response: str, use_chat_stop_tokens: bool = False) -> str: - """Filter end tokens from a complete response (batch mode). - - This method applies the same filtering logic as the streaming mode - to ensure consistent behavior between streaming and non-streaming. - - Args: - response: The complete generated response - use_chat_stop_tokens: Whether to apply chat stop tokens - - Returns: - Response with end tokens filtered out - """ - # Apply native stop token filtering FIRST (highest priority) - native_stop_tokens = self._stop_tokens if self._stop_tokens else [] - for stop_token in native_stop_tokens: - if stop_token in response: - # Find the stop token position and return everything before it - stop_pos = response.find(stop_token) - return response[:stop_pos] - - # Only check chat stop tokens if no native stop token found (fallback) - if use_chat_stop_tokens and self._chat_stop_tokens: - for stop_token in self._chat_stop_tokens: - if stop_token in response: - # Find the stop token position and return everything before it - stop_pos = response.find(stop_token) - return response[:stop_pos] - - # No stop tokens found, return original response - return response - - -def get_gpu_status() -> Dict[str, float]: - """Independent GPU status check - usable from anywhere. - - Returns: - Dictionary with GPU memory statistics in GB - """ - return { - "active_memory_gb": mx.get_active_memory() / 1024**3, - "peak_memory_gb": mx.get_peak_memory() / 1024**3, - } - - -def check_memory_available(required_gb: float) -> bool: - """Pre-flight check before model loading. - - Args: - required_gb: Required memory in GB - - Returns: - True if memory is likely available (conservative estimate) - """ - current_memory = mx.get_active_memory() / 1024**3 - - # Conservative estimate: assume system has at least 8GB unified memory - # and we should leave some headroom (2GB) for system processes - estimated_total = 8.0 # This could be improved by detecting actual system memory - available = estimated_total - current_memory - 2.0 # 2GB headroom - - return available >= required_gb - - -def run_model_enhanced( - model_path: str, - prompt: Optional[str] = None, - interactive: bool = False, - max_tokens: int = 500, - temperature: float = 0.7, - top_p: float = 0.9, - repetition_penalty: float = 1.1, - stream: bool = True, - use_chat_template: bool = True, - verbose: bool = False, -) -> Optional[str]: - """Enhanced run function with direct MLX integration. - - Uses context manager pattern for automatic resource cleanup. - - Args: - model_path: Path to the MLX model - prompt: Input prompt (if None, enters interactive mode) - interactive: Force interactive mode - max_tokens: Maximum tokens to generate - temperature: Sampling temperature - top_p: Top-p sampling parameter - repetition_penalty: Penalty for repeated tokens - stream: Whether to stream output - - Returns: - Generated text (in non-interactive mode) - """ - try: - with MLXRunner(model_path, verbose=verbose) as runner: - # Interactive mode - if interactive or prompt is None: - runner.interactive_chat( - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - ) - return None - - # Single prompt mode - if verbose: - print(f"\nPrompt: {prompt}\n") - print("Response: ", end="", flush=True) - - if stream: - # Streaming generation - response_tokens = [] - for token in runner.generate_streaming( - prompt=prompt, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - use_chat_template=use_chat_template, - ): - print(token, end="", flush=True) - response_tokens.append(token) - - response = "".join(response_tokens) - else: - # Batch generation - response = runner.generate_batch( - prompt=prompt, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - use_chat_template=use_chat_template, - ) - print(response) - - # Show memory usage if verbose - if verbose: - memory_stats = runner.get_memory_usage() - print(f"\n\nMemory: {memory_stats['model_gb']:.1f}GB model, {memory_stats['current_gb']:.1f}GB total") - - return response - - # Note: cleanup happens automatically due to context manager - - except Exception as e: - print(f"\n[ERROR] {e}") - return None diff --git a/mlx_knife/server.py b/mlx_knife/server.py deleted file mode 100644 index f0e8810..0000000 --- a/mlx_knife/server.py +++ /dev/null @@ -1,555 +0,0 @@ -# mlx_knife/server.py -""" -OpenAI-compatible API server for MLX models. -Provides REST endpoints for text generation with MLX backend. -""" - -import json -import time -import uuid -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager -from typing import Any, Dict, List, Optional, Union - -import uvicorn -from fastapi import FastAPI, HTTPException -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse -from pydantic import BaseModel, Field - -from .cache_utils import detect_framework, is_model_healthy -from .mlx_runner import MLXRunner - -# Global model cache and configuration -_model_cache: Dict[str, MLXRunner] = {} -_current_model_path: Optional[str] = None -_default_max_tokens: Optional[int] = None # Use dynamic model-aware limits by default - - -class CompletionRequest(BaseModel): - model: str - prompt: Union[str, List[str]] - max_tokens: Optional[int] = None - temperature: Optional[float] = 0.7 - top_p: Optional[float] = 0.9 - stream: Optional[bool] = False - stop: Optional[Union[str, List[str]]] = None - repetition_penalty: Optional[float] = 1.1 - - -class ChatMessage(BaseModel): - role: str = Field(..., pattern="^(system|user|assistant)$") - content: str - - -class ChatCompletionRequest(BaseModel): - model: str - messages: List[ChatMessage] - max_tokens: Optional[int] = None - temperature: Optional[float] = 0.7 - top_p: Optional[float] = 0.9 - stream: Optional[bool] = False - stop: Optional[Union[str, List[str]]] = None - repetition_penalty: Optional[float] = 1.1 - - -class CompletionResponse(BaseModel): - id: str - object: str = "text_completion" - created: int - model: str - choices: List[Dict[str, Any]] - usage: Dict[str, int] - - -class ChatCompletionResponse(BaseModel): - id: str - object: str = "chat.completion" - created: int - model: str - choices: List[Dict[str, Any]] - usage: Dict[str, int] - - -class ModelInfo(BaseModel): - id: str - object: str = "model" - owned_by: str = "mlx-knife" - permission: List = [] - context_length: Optional[int] = None - - - -def get_or_load_model(model_spec: str, verbose: bool = False) -> MLXRunner: - """Get model from cache or load it if not cached.""" - global _model_cache, _current_model_path - - # Use the existing model path resolution from cache_utils - from .cache_utils import get_model_path - - try: - model_path, model_name, commit_hash = get_model_path(model_spec) - if not model_path.exists(): - raise HTTPException(status_code=404, detail=f"Model {model_spec} not found in cache") - except Exception as e: - raise HTTPException(status_code=404, detail=f"Model {model_spec} not found: {str(e)}") - - # Check if it's an MLX model - framework = detect_framework(model_path.parent.parent, model_name) - if framework != "MLX": - raise HTTPException(status_code=400, detail=f"Model {model_name} is not a valid MLX model (Framework: {framework})") - - model_path_str = str(model_path) - - # Check if we need to load a different model - if _current_model_path != model_path_str: - # Clear cache if switching models to avoid memory issues - _model_cache.clear() - - # Load new model - if verbose: - print(f"Loading model: {model_name}") - - runner = MLXRunner(model_path_str, verbose=verbose) - runner.load_model() - - _model_cache[model_path_str] = runner - _current_model_path = model_path_str - - return _model_cache[model_path_str] - - -async def generate_completion_stream( - runner: MLXRunner, - prompt: str, - request: CompletionRequest -) -> AsyncGenerator[str, None]: - """Generate streaming completion response.""" - completion_id = f"cmpl-{uuid.uuid4()}" - created = int(time.time()) - - # Yield initial response - initial_response = { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": request.model, - "choices": [ - { - "index": 0, - "text": "", - "logprobs": None, - "finish_reason": None - } - ] - } - - yield f"data: {json.dumps(initial_response)}\n\n" - - # Stream tokens - try: - token_count = 0 - for token in runner.generate_streaming( - prompt=prompt, - max_tokens=runner.get_effective_max_tokens(request.max_tokens or _default_max_tokens, interactive=False), - temperature=request.temperature, - top_p=request.top_p, - repetition_penalty=request.repetition_penalty, - use_chat_template=False # Raw completion mode - ): - token_count += 1 - - chunk_response = { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": request.model, - "choices": [ - { - "index": 0, - "text": token, - "logprobs": None, - "finish_reason": None - } - ] - } - - yield f"data: {json.dumps(chunk_response)}\n\n" - - # Check for stop sequences - if request.stop: - stop_sequences = request.stop if isinstance(request.stop, list) else [request.stop] - if any(stop in token for stop in stop_sequences): - break - - except Exception as e: - error_response = { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": request.model, - "choices": [ - { - "index": 0, - "text": "", - "logprobs": None, - "finish_reason": "error" - } - ], - "error": str(e) - } - yield f"data: {json.dumps(error_response)}\n\n" - - # Final response - final_response = { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": request.model, - "choices": [ - { - "index": 0, - "text": "", - "logprobs": None, - "finish_reason": "stop" - } - ] - } - - yield f"data: {json.dumps(final_response)}\n\n" - yield "data: [DONE]\n\n" - - -async def generate_chat_stream( - runner: MLXRunner, - messages: List[ChatMessage], - request: ChatCompletionRequest -) -> AsyncGenerator[str, None]: - """Generate streaming chat completion response.""" - completion_id = f"chatcmpl-{uuid.uuid4()}" - created = int(time.time()) - - # Convert messages to prompt - prompt = format_chat_messages(messages) - - # Yield initial response - initial_response = { - "id": completion_id, - "object": "chat.completion.chunk", - "created": created, - "model": request.model, - "choices": [ - { - "index": 0, - "delta": {"role": "assistant", "content": ""}, - "finish_reason": None - } - ] - } - - yield f"data: {json.dumps(initial_response)}\n\n" - - # Stream tokens - try: - for token in runner.generate_streaming( - prompt=prompt, - max_tokens=runner.get_effective_max_tokens(request.max_tokens or _default_max_tokens, interactive=False), - temperature=request.temperature, - top_p=request.top_p, - repetition_penalty=request.repetition_penalty, - use_chat_template=True - ): - chunk_response = { - "id": completion_id, - "object": "chat.completion.chunk", - "created": created, - "model": request.model, - "choices": [ - { - "index": 0, - "delta": {"content": token}, - "finish_reason": None - } - ] - } - - yield f"data: {json.dumps(chunk_response)}\n\n" - - # Check for stop sequences - if request.stop: - stop_sequences = request.stop if isinstance(request.stop, list) else [request.stop] - if any(stop in token for stop in stop_sequences): - break - - except Exception as e: - error_response = { - "id": completion_id, - "object": "chat.completion.chunk", - "created": created, - "model": request.model, - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": "error" - } - ], - "error": str(e) - } - yield f"data: {json.dumps(error_response)}\n\n" - - # Final response - final_response = { - "id": completion_id, - "object": "chat.completion.chunk", - "created": created, - "model": request.model, - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": "stop" - } - ] - } - - yield f"data: {json.dumps(final_response)}\n\n" - yield "data: [DONE]\n\n" - - -def format_chat_messages(messages: List[ChatMessage]) -> str: - """Convert chat messages to a prompt string.""" - # Simple format - models with chat templates will format properly - formatted = [] - for message in messages: - if message.role == "system": - formatted.append(f"System: {message.content}") - elif message.role == "user": - formatted.append(f"Human: {message.content}") - elif message.role == "assistant": - formatted.append(f"Assistant: {message.content}") - - return "\n\n".join(formatted) - - -def count_tokens(text: str) -> int: - """Rough token count estimation.""" - return int(len(text.split()) * 1.3) # Approximation, convert to int - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Manage application lifespan.""" - print("MLX Knife Server starting up...") - yield - print("MLX Knife Server shutting down...") - # Clean up model cache - global _model_cache - _model_cache.clear() - - -# Create FastAPI app -from . import __version__ - -app = FastAPI( - title="MLX Knife API", - description="OpenAI-compatible API for MLX models", - version=__version__, - lifespan=lifespan -) - -# Add CORS middleware for browser access -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # Allow all origins for local development - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -@app.get("/health") -async def health_check(): - """Health check endpoint (OpenAI compatible).""" - return {"status": "healthy", "service": "mlx-knife-server"} - - - - -@app.get("/v1/models") -async def list_models(): - """List available models.""" - from .cache_utils import MODEL_CACHE, cache_dir_to_hf - - model_list = [] - models = [d for d in MODEL_CACHE.iterdir() if d.name.startswith("models--")] - - for model_dir in models: - model_name = cache_dir_to_hf(model_dir.name) - framework = detect_framework(model_dir, model_name) - - if framework == "MLX" and is_model_healthy(model_name): - # Get model context length - context_length = None - try: - from .cache_utils import get_model_path - from .mlx_runner import get_model_context_length - model_path_tuple = get_model_path(model_name) - if model_path_tuple and model_path_tuple[0]: - context_length = get_model_context_length(str(model_path_tuple[0])) - except Exception: - pass # Fallback to None if context length cannot be determined - - model_list.append(ModelInfo( - id=model_name, - object="model", - owned_by="mlx-knife", - context_length=context_length - )) - - return {"object": "list", "data": model_list} - - -@app.post("/v1/completions") -async def create_completion(request: CompletionRequest): - """Create a text completion.""" - try: - runner = get_or_load_model(request.model) - - # Handle array of prompts - if isinstance(request.prompt, list): - if len(request.prompt) > 1: - raise HTTPException(status_code=400, detail="Multiple prompts not supported yet") - prompt = request.prompt[0] - else: - prompt = request.prompt - - if request.stream: - # Streaming response - return StreamingResponse( - generate_completion_stream(runner, prompt, request), - media_type="text/plain", - headers={"Cache-Control": "no-cache"} - ) - else: - # Non-streaming response - completion_id = f"cmpl-{uuid.uuid4()}" - created = int(time.time()) - - generated_text = runner.generate_batch( - prompt=prompt, - max_tokens=runner.get_effective_max_tokens(request.max_tokens or _default_max_tokens, interactive=False), - temperature=request.temperature, - top_p=request.top_p, - repetition_penalty=request.repetition_penalty, - use_chat_template=False - ) - - prompt_tokens = count_tokens(prompt) - completion_tokens = count_tokens(generated_text) - - return CompletionResponse( - id=completion_id, - created=created, - model=request.model, - choices=[ - { - "index": 0, - "text": generated_text, - "logprobs": None, - "finish_reason": "stop" - } - ], - usage={ - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens - } - ) - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@app.post("/v1/chat/completions") -async def create_chat_completion(request: ChatCompletionRequest): - """Create a chat completion.""" - try: - runner = get_or_load_model(request.model) - - if request.stream: - # Streaming response - return StreamingResponse( - generate_chat_stream(runner, request.messages, request), - media_type="text/plain", - headers={"Cache-Control": "no-cache"} - ) - else: - # Non-streaming response - completion_id = f"chatcmpl-{uuid.uuid4()}" - created = int(time.time()) - - # Format messages to prompt - prompt = format_chat_messages(request.messages) - - generated_text = runner.generate_batch( - prompt=prompt, - max_tokens=runner.get_effective_max_tokens(request.max_tokens or _default_max_tokens, interactive=False), - temperature=request.temperature, - top_p=request.top_p, - repetition_penalty=request.repetition_penalty, - use_chat_template=True - ) - - # Token counting - total_prompt = "\n\n".join([msg.content for msg in request.messages]) - prompt_tokens = count_tokens(total_prompt) - completion_tokens = count_tokens(generated_text) - - return ChatCompletionResponse( - id=completion_id, - created=created, - model=request.model, - choices=[ - { - "index": 0, - "message": { - "role": "assistant", - "content": generated_text - }, - "finish_reason": "stop" - } - ], - usage={ - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens - } - ) - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -def run_server( - host: str = "127.0.0.1", - port: int = 8000, - max_tokens: int = 2000, - reload: bool = False, - log_level: str = "info" -): - """Run the MLX Knife server.""" - global _default_max_tokens - _default_max_tokens = max_tokens - - print(f"Starting MLX Knife Server on http://{host}:{port}") - print(f"API docs available at http://{host}:{port}/docs") - print(f"Default max tokens: {'model-aware dynamic limits' if max_tokens is None else max_tokens}") - - uvicorn.run( - "mlx_knife.server:app", - host=host, - port=port, - reload=reload, - log_level=log_level - ) diff --git a/mlx_knife/throttled_download_worker.py b/mlx_knife/throttled_download_worker.py deleted file mode 100644 index 50b5b6e..0000000 --- a/mlx_knife/throttled_download_worker.py +++ /dev/null @@ -1,162 +0,0 @@ -import json -import os -import signal -import sys -import time -from typing import Any - -# Global tracking for accurate download rate -_download_stats = { - 'bytes_downloaded': 0, - 'start_time': None, - 'last_update': None, - 'actual_download_time': 0.0 # Time spent actually downloading (without delays) -} - - -def signal_handler(signum: int, frame: Any) -> None: - print("\n[WARNING] Download cancelled by user.") - sys.exit(0) - -signal.signal(signal.SIGINT, signal_handler) -signal.signal(signal.SIGTERM, signal_handler) - -os.environ["HF_HUB_DOWNLOAD_THREADS"] = "1" -os.environ["HF_HUB_DOWNLOAD_CHUNK_SIZE"] = "524288" # 512KB chunks (half size) -os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "false" - -try: - import requests - from huggingface_hub import snapshot_download -except ImportError: - print("[ERROR] huggingface_hub or requests not installed in worker environment!") - sys.exit(2) - -# Throttle all HTTP(S) requests with adaptive delays -original_get = requests.get -original_post = requests.post - -def get_adaptive_delay(url: str, response: Any) -> float: - """Calculate delay based on file type and size""" - if not url: - return 1.0 - - # Check if this is a large model file download - if any(ext in url.lower() for ext in ['.safetensors', '.bin', '.pth']): - # For large model files, use more aggressive throttling - content_length = response.headers.get('content-length') - if content_length: - size_mb = int(content_length) / (1024 * 1024) - if size_mb > 100: # Files larger than 100MB - return 3.0 # 3 second delay between chunks - elif size_mb > 10: # Files larger than 10MB - return 2.0 # 2 second delay - return 2.0 # Default for model files - - # Regular files (config.json, tokenizer files, etc.) - return 0.5 - -def throttled_get(*args: Any, **kwargs: Any) -> Any: - download_start = time.time() - response = original_get(*args, **kwargs) - download_end = time.time() - - # Track actual download time (without delays) - actual_download_time = download_end - download_start - _download_stats['actual_download_time'] += actual_download_time - - # Track bytes if we can determine them - url = args[0] if args else kwargs.get('url', '') - if hasattr(response, 'headers') and 'content-length' in response.headers: - content_length = int(response.headers['content-length']) - _download_stats['bytes_downloaded'] += content_length - - # Initialize timing if first download - if _download_stats['start_time'] is None: - _download_stats['start_time'] = download_start - - # Print accurate rate every ~5MB or every 10 seconds - now = time.time() - if (_download_stats['last_update'] is None or - now - _download_stats['last_update'] > 10 or - _download_stats['bytes_downloaded'] % (5 * 1024 * 1024) < content_length): - - if _download_stats['actual_download_time'] > 0: - real_rate_mbps = (_download_stats['bytes_downloaded'] / _download_stats['actual_download_time']) / (1024 * 1024) - total_mb = _download_stats['bytes_downloaded'] / (1024 * 1024) - print(f"[THROTTLE] Downloaded {total_mb:.1f}MB at real rate: {real_rate_mbps:.1f}MB/s (excluding delays)") - _download_stats['last_update'] = now - - delay = get_adaptive_delay(url, response) - time.sleep(delay) - return response - -def throttled_post(*args: Any, **kwargs: Any) -> Any: - response = original_post(*args, **kwargs) - time.sleep(0.5) - return response - -requests.get = throttled_get -requests.post = throttled_post - -def main() -> None: - if len(sys.argv) != 2: - print("Usage: python throttled_download_worker.py ") - sys.exit(1) - - kwargs_file = sys.argv[1] - try: - with open(kwargs_file) as f: - kwargs_dict = json.load(f) - except Exception as e: - print(f"[ERROR] Could not read worker kwargs: {e}") - sys.exit(1) - - try: - snapshot_download(**kwargs_dict) - except requests.exceptions.HTTPError as e: - status = getattr(e.response, "status_code", None) - url = getattr(e.response, "url", None) - if status == 401: - print(f"[ERROR] Unauthorized (401): Check your HuggingFace token or login.\nURL: {url}") - sys.exit(10) - elif status == 403: - print(f"[ERROR] Forbidden (403): Access denied.\nURL: {url}") - sys.exit(11) - elif status == 404: - print(f"[ERROR] Not Found (404): Resource does not exist.\nURL: {url}") - sys.exit(12) - else: - print(f"[ERROR] HTTP Error: {e}") - sys.exit(2) - except requests.exceptions.ConnectionError: - print("[ERROR] Network connection error. Please check your internet connection and try again.") - sys.exit(20) - except PermissionError as e: - print(f"[ERROR] Permission denied: {e.filename if hasattr(e, 'filename') else 'check file permissions'}") - print(" Ensure you have write access to the cache directory.") - sys.exit(13) - except OSError as e: - import errno - if e.errno == errno.ENOSPC: - print("[ERROR] No space left on device. Please free up disk space and try again.") - sys.exit(14) - elif e.errno == errno.EACCES: - print(f"[ERROR] Access denied: {e.filename if hasattr(e, 'filename') else 'check permissions'}") - sys.exit(13) - else: - print(f"[ERROR] OS Error during download: {e}") - sys.exit(15) - except Exception as e: - print(f"[ERROR] Unexpected error during download: {type(e).__name__}: {e}") - sys.exit(2) - finally: - try: - os.unlink(kwargs_file) - except Exception: - pass - - sys.exit(0) - -if __name__ == "__main__": - main() diff --git a/mlxk2/NOTICE b/mlxk2/NOTICE new file mode 100644 index 0000000..61d3ff0 --- /dev/null +++ b/mlxk2/NOTICE @@ -0,0 +1,5 @@ +MLX-Knife 2.0 (mlxk2) +Copyright 2025 The BROKE team + +This product includes software developed by The BROKE team. +Licensed under the Apache License, Version 2.0. diff --git a/mlxk2/__init__.py b/mlxk2/__init__.py index 39e1ae7..2fa9a1d 100644 --- a/mlxk2/__init__.py +++ b/mlxk2/__init__.py @@ -7,4 +7,4 @@ import warnings # Issue parity with 1.1.0 (Issue #22) warnings.filterwarnings('ignore', message='urllib3 v2 only supports OpenSSL 1.1.1+') -__version__ = "2.0.0-alpha.2" +__version__ = "2.0.0-alpha.3" diff --git a/mlxk2/cli.py b/mlxk2/cli.py index 706e0ff..6d8f2dc 100644 --- a/mlxk2/cli.py +++ b/mlxk2/cli.py @@ -106,7 +106,7 @@ def main(): push_parser = subparsers.add_parser("push", help="EXPERIMENTAL: Upload a local folder to Hugging Face") push_parser.add_argument("local_dir", help="Local folder to upload") push_parser.add_argument("repo_id", help="Target repo as org/model") - push_parser.add_argument("--create", action="store_true", help="Create repository if missing") + push_parser.add_argument("--create", action="store_true", help="Create repository/branch if missing") # Alpha.1 safety: require --private to avoid accidental public uploads push_parser.add_argument( "--private", diff --git a/mlxk2/operations/common.py b/mlxk2/operations/common.py new file mode 100644 index 0000000..10bf6de --- /dev/null +++ b/mlxk2/operations/common.py @@ -0,0 +1,270 @@ +"""Common helpers for model metadata detection (2.0). + +Lenient framework/type detection for Issue #31 port: +- Prefer MLX for mlx-community/* or when README front-matter indicates MLX. +- Detect chat type via name, config, or tokenizer chat_template hints. + +Parsing is intentionally lightweight (no YAML dependency). Front-matter is +parsed from the first '---' block in README.md when present. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional +import json as _json + + +@dataclass +class FrontMatter: + tags: list[str] + library_name: Optional[str] + + +def read_front_matter(root: Path) -> Optional[FrontMatter]: + """Best-effort parse of README.md YAML-like front matter. + + Supports: + - Inline list: tags: [mlx, chat] + - Block list: + tags: + - mlx + - chat + - library_name: mlx + Returns None if README.md or front-matter block missing. + """ + try: + readme = root / "README.md" + if not readme.exists() or not readme.is_file(): + return None + lines = readme.read_text(encoding="utf-8", errors="ignore").splitlines() + if not lines or lines[0].strip() != "---": + return None + # Extract the first front-matter block + block: list[str] = [] + for line in lines[1:]: + if line.strip() == "---": + break + block.append(line.rstrip("\n")) + if not block: + return None + + tags: list[str] = [] + library_name: Optional[str] = None + + # Simple state machine for tags block list + in_tags_block = False + for raw in block: + s = raw.strip() + if not s: + continue + # library_name: value + if s.lower().startswith("library_name:"): + try: + library_name = s.split(":", 1)[1].strip().strip('"\'') + except Exception: + pass + in_tags_block = False + continue + + # tags: [a, b] + if s.lower().startswith("tags:") and "[" in s and "]" in s: + try: + inside = s.split("[", 1)[1].rsplit("]", 1)[0] + parts = [p.strip().strip('"\'') for p in inside.split(",") if p.strip()] + tags.extend([p for p in parts if p]) + except Exception: + pass + in_tags_block = False + continue + + # tags: (start of block list) + if s.lower().startswith("tags:"): + in_tags_block = True + continue + + if in_tags_block: + # Expect lines like "- mlx" + try: + if s.startswith("-"): + val = s.lstrip("-").strip().strip('"\'') + if val: + tags.append(val) + else: + # Any other non-dash line ends the block + in_tags_block = False + except Exception: + pass + + return FrontMatter(tags=tags, library_name=library_name) + except Exception: + return None + + +def read_tokenizer_hints(root: Path) -> Dict[str, Any]: + """Extract lightweight tokenizer hints (e.g., chat_template presence).""" + hints: Dict[str, Any] = {"chat_template": None} + try: + for fname in ("tokenizer_config.json", "tokenizer.json"): + fp = root / fname + if fp.exists() and fp.is_file(): + try: + obj = _json.loads(fp.read_text(encoding="utf-8", errors="ignore")) + except Exception: + obj = None + if isinstance(obj, dict): + ct = obj.get("chat_template") + if isinstance(ct, str) and ct.strip(): + hints["chat_template"] = ct + break + except Exception: + pass + return hints + + +def _has_any(path: Path, patterns: tuple[str, ...]) -> bool: + try: + for pat in patterns: + if any(path.glob(pat)): + return True + except Exception: + return False + return False + + +def detect_framework(hf_name: str, model_root: Path, selected_path: Optional[Path] = None, fm: Optional[FrontMatter] = None) -> str: + """Lenient framework detection. + + MLX if: + - org is mlx-community/*, or + - README front-matter tags include 'mlx', or + - README front-matter library_name == 'mlx'. + + Else GGUF if any *.gguf present under selected_path or snapshots. + Else PyTorch if any *.safetensors or pytorch_model.bin present under snapshots. + Else Unknown. + """ + try: + if "mlx-community/" in hf_name: + return "MLX" + + # Front-matter signals + if fm is not None: + tags = [t.lower() for t in (fm.tags or [])] + lib = (fm.library_name or "").lower() + if "mlx" in tags or lib == "mlx": + return "MLX" + + # Search location preference: selected snapshot, else model root + root = selected_path if selected_path is not None else model_root + + if _has_any(root, ("**/*.gguf",)): + return "GGUF" + + # Look under snapshots for common formats + snapshots_dir = model_root / "snapshots" + if _has_any(snapshots_dir, ("**/*.safetensors", "**/pytorch_model.bin")): + return "PyTorch" + except Exception: + pass + return "Unknown" + + +def detect_model_type(hf_name: str, config: Optional[Dict[str, Any]], tok_hints: Dict[str, Any]) -> str: + name = hf_name.lower() + if "embed" in name: + return "embedding" + if (config or {}).get("model_type") == "chat": + return "chat" + ct = tok_hints.get("chat_template") + if isinstance(ct, str) and ct.strip(): + return "chat" + if "instruct" in name or "chat" in name: + return "chat" + return "base" + + +def detect_capabilities(model_type: str, hf_name: str, tok_hints: Dict[str, Any], config: Optional[Dict[str, Any]]) -> list[str]: + if model_type == "embedding": + return ["embeddings"] + caps = ["text-generation"] + name = hf_name.lower() + ct = tok_hints.get("chat_template") + if model_type == "chat" or "instruct" in name or "chat" in name or (isinstance(ct, str) and ct.strip()): + caps.append("chat") + return caps + + +def _iso8601_utc_from_mtime(p: Path) -> str: + try: + from datetime import datetime + return datetime.fromtimestamp(p.stat().st_mtime).strftime("%Y-%m-%dT%H:%M:%SZ") + except Exception: + return "1970-01-01T00:00:00Z" + + +def _total_size_bytes(path: Path) -> int: + try: + total = 0 + for f in path.rglob("*"): + if f.is_file(): + total += f.stat().st_size + return total + except Exception: + return 0 + + +def _load_config_json(path: Path) -> Optional[Dict[str, Any]]: + try: + fp = path / "config.json" + if fp.exists(): + return _json.loads(fp.read_text(encoding="utf-8", errors="ignore")) + except Exception: + pass + return None + + +def build_model_object(hf_name: str, model_root: Path, selected_path: Optional[Path]) -> Dict[str, Any]: + """Build the common model object for list/show using unified detection. + + selected_path: points at the chosen snapshot directory when available; otherwise + may be the model_root. Commit hash is taken from selected_path.name if it looks + like a 40-char hex string, else None. + """ + from ..operations.health import is_model_healthy # local import to avoid cycle + + # Compute commit hash if selected path is a snapshot dir + commit_hash: Optional[str] = None + if selected_path is not None: + name = selected_path.name + if len(name) == 40 and all(c in "0123456789abcdef" for c in name.lower()): + commit_hash = name + + # Read hints from selected snapshot if possible; fall back to model root + probe = selected_path if selected_path is not None else model_root + fm = read_front_matter(probe) + tok = read_tokenizer_hints(probe) + config = _load_config_json(probe) + + framework = detect_framework(hf_name, model_root, selected_path=selected_path, fm=fm) + model_type = detect_model_type(hf_name, config, tok) + capabilities = detect_capabilities(model_type, hf_name, tok, config) + + # Health: rely on existing operation (name-based) + healthy, _reason = is_model_healthy(hf_name) + + # Size/Modified computed from selected path (snapshot preferred) + base = selected_path if selected_path is not None else model_root + model_obj = { + "name": hf_name, + "hash": commit_hash, + "size_bytes": _total_size_bytes(base), + "last_modified": _iso8601_utc_from_mtime(base), + "framework": framework, + "model_type": model_type, + "capabilities": capabilities, + "health": "healthy" if healthy else "unhealthy", + "cached": True, + } + return model_obj diff --git a/mlxk2/operations/list.py b/mlxk2/operations/list.py index 995131f..23df761 100644 --- a/mlxk2/operations/list.py +++ b/mlxk2/operations/list.py @@ -1,21 +1,9 @@ """List models operation for MLX-Knife 2.0.""" -from datetime import datetime from typing import Dict, Any, Optional, Tuple from ..core.cache import get_current_model_cache, cache_dir_to_hf -from .health import is_model_healthy - - -def _total_size_bytes(model_path) -> int: - """Calculate total model size in bytes for a given path.""" - if not model_path.exists(): - return 0 - total_size = 0 - for file in model_path.rglob("*"): - if file.is_file(): - total_size += file.stat().st_size - return total_size +from .common import build_model_object def _latest_snapshot(model_path) -> Tuple[Optional[str], Optional[object]]: @@ -30,48 +18,6 @@ def _latest_snapshot(model_path) -> Tuple[Optional[str], Optional[object]]: return latest.name, latest -def detect_framework(model_path, hf_name): - """Detect model framework without exposing internal logic.""" - if "mlx-community" in hf_name: - return "MLX" - - # Check for GGUF files - if list(model_path.glob("**/*.gguf")): - return "GGUF" - - # Check for common formats - snapshots_dir = model_path / "snapshots" - if snapshots_dir.exists(): - has_safetensors = any(snapshots_dir.glob("**/*.safetensors")) - has_pytorch_bin = any(snapshots_dir.glob("**/pytorch_model.bin")) - - if has_safetensors: - return "PyTorch" - elif has_pytorch_bin: - return "PyTorch" - - return "Unknown" - - -def detect_model_type(hf_name: str) -> str: - n = hf_name.lower() - if "embed" in n: - return "embedding" - if "instruct" in n or "chat" in n: - return "chat" - return "base" - - -def detect_capabilities(hf_name: str) -> list: - n = hf_name.lower() - if "embed" in n: - return ["embeddings"] - caps = ["text-generation"] - if "instruct" in n or "chat" in n: - caps.append("chat") - return caps - - def list_models(pattern: str = None) -> Dict[str, Any]: """List all models in cache with JSON output. @@ -107,25 +53,10 @@ def list_models(pattern: str = None) -> Dict[str, Any]: if pattern.lower() not in hf_name.lower(): continue - # Select snapshot (prefer latest) and compute fields - commit_hash, snap_path = _latest_snapshot(model_dir) - selected_path = snap_path if snap_path is not None else model_dir - last_modified = datetime.fromtimestamp(selected_path.stat().st_mtime).strftime("%Y-%m-%dT%H:%M:%SZ") - size_bytes = _total_size_bytes(selected_path) - healthy, _reason = is_model_healthy(hf_name) - - # Minimal model object per spec 0.1.2 - models.append({ - "name": hf_name, - "hash": commit_hash, - "size_bytes": size_bytes, - "last_modified": last_modified, - "framework": detect_framework(model_dir, hf_name), - "model_type": detect_model_type(hf_name), - "capabilities": detect_capabilities(hf_name), - "health": "healthy" if healthy else "unhealthy", - "cached": True, - }) + # Select snapshot (prefer latest) and build model object + _hash, snap_path = _latest_snapshot(model_dir) + model_obj = build_model_object(hf_name, model_dir, snap_path if snap_path is not None else model_dir) + models.append(model_obj) # Sort by name for consistent output models.sort(key=lambda x: x["name"]) diff --git a/mlxk2/operations/push.py b/mlxk2/operations/push.py index 6b6be1f..d211ac4 100644 --- a/mlxk2/operations/push.py +++ b/mlxk2/operations/push.py @@ -13,7 +13,7 @@ from __future__ import annotations import os from pathlib import Path -from typing import Dict, Any, List, Tuple, Optional +from typing import Dict, Any, List, Optional import json as _json @@ -163,7 +163,7 @@ def push_operation( # 4) Ensure repo exists (model type). Do not auto-create branch here. created_repo = False try: - # If branch does not exist, this may raise; that is acceptable for M0. + # If branch does not exist, this raises RevisionNotFoundError. api.repo_info(repo_id=repo_id, repo_type="model", revision=branch) except RepositoryNotFoundError: if dry_run: @@ -187,14 +187,25 @@ def push_operation( "message": f"Repository not found: {repo_id} (use --create)", } return result - # Try create + # Try create repository (exist_ok=True covers races) api.create_repo( repo_id=repo_id, repo_type="model", private=private, exist_ok=True ) # After create, no guarantee branch exists; upload_folder below will target revision created_repo = True + # Ensure target branch exists if not default + try: + if branch and branch != DEFAULT_PUSH_BRANCH: + api.create_branch(repo_id=repo_id, repo_type="model", branch=branch) + except HfHubHTTPError as e: + result["status"] = "error" + result["error"] = { + "type": "branch_create_failed", + "message": str(e), + } + return result except RevisionNotFoundError: - # Repo exists but branch doesn't; allow upload_folder to create the branch/commit. + # Repo exists but branch doesn't. if dry_run: local_files = _collect_local_files(p, ignore_patterns) result["data"].update({ @@ -208,12 +219,18 @@ def push_operation( "would_create_branch": True, }) return result - pass + # If user asked to create, proactively create the branch to avoid 404 on preupload; + # otherwise, tolerate and let upload_folder attempt (offline tests expect this). + if create: + try: + api.create_branch(repo_id=repo_id, repo_type="model", branch=branch) + except HfHubHTTPError: + # Do not fail early; fall through and let upload attempt once + pass # 4b) If dry-run and repo/branch exist: compute diff vs remote and return if dry_run: try: - from fnmatch import fnmatch remote_files = set(api.list_repo_files(repo_id=repo_id, repo_type="model", revision=branch or DEFAULT_PUSH_BRANCH) or []) except Exception: remote_files = set() @@ -246,7 +263,6 @@ def push_operation( try: import logging as _logging import contextlib as _contextlib - import sys as _sys _hf_logger = _logging.getLogger("huggingface_hub") class _BufHandler(_logging.Handler): @@ -280,20 +296,8 @@ def push_operation( _hf_logger.handlers = [_handler] # keep only our buffer in quiet mode # Silence tqdm progress bars to stderr as an extra safety in quiet mode - if quiet: - with open(os.devnull, "w") as _devnull: - with _contextlib.redirect_stderr(_devnull): - info = upload_folder( - repo_id=repo_id, - repo_type="model", - folder_path=str(p), - revision=branch or DEFAULT_PUSH_BRANCH, - commit_message=commit_msg, - token=hf_token, - ignore_patterns=ignore_patterns, - ) - else: - info = upload_folder( + def _do_upload(): + return upload_folder( repo_id=repo_id, repo_type="model", folder_path=str(p), @@ -302,6 +306,13 @@ def push_operation( token=hf_token, ignore_patterns=ignore_patterns, ) + + if quiet: + with open(os.devnull, "w") as _devnull: + with _contextlib.redirect_stderr(_devnull): + info = _do_upload() + else: + info = _do_upload() hf_logs = getattr(_handler, "buf", None) finally: # Restore logger state @@ -325,12 +336,39 @@ def push_operation( except Exception: pass except HfHubHTTPError as he: - result["status"] = "error" - result["error"] = { - "type": "upload_failed", - "message": str(he), - } - return result + # In some hub versions, uploading to a non-existent branch raises here. + # If --create was given, try to create the branch and retry once. + msg = str(he) + if create and ("Revision Not Found" in msg or "Invalid rev id" in msg): + try: + api.create_branch(repo_id=repo_id, repo_type="model", branch=branch) + # Retry upload once + try: + info = upload_folder( + repo_id=repo_id, + repo_type="model", + folder_path=str(p), + revision=branch or DEFAULT_PUSH_BRANCH, + commit_message=commit_msg, + token=hf_token, + ignore_patterns=ignore_patterns, + ) + hf_logs = hf_logs or [] + except HfHubHTTPError as he2: + result["status"] = "error" + result["error"] = {"type": "upload_failed", "message": str(he2)} + return result + except HfHubHTTPError as ce: + result["status"] = "error" + result["error"] = {"type": "branch_create_failed", "message": str(ce)} + return result + else: + result["status"] = "error" + result["error"] = { + "type": "upload_failed", + "message": str(he), + } + return result except Exception as e: result["status"] = "error" result["error"] = { diff --git a/mlxk2/operations/show.py b/mlxk2/operations/show.py index 4cd89e6..66a91e2 100644 --- a/mlxk2/operations/show.py +++ b/mlxk2/operations/show.py @@ -1,12 +1,11 @@ """Show model operation for MLX-Knife 2.0.""" import json -from datetime import datetime from typing import Dict, Any -from ..core.cache import MODEL_CACHE, hf_to_cache_dir +from ..core.cache import get_current_model_cache, hf_to_cache_dir from ..core.model_resolution import resolve_model_for_operation -from .health import is_model_healthy +from .common import build_model_object def get_file_type(file_name): @@ -109,60 +108,8 @@ def get_config_content(model_path): return None -def detect_model_capabilities(hf_name, config_data): - """Detect model capabilities from name and config.""" - capabilities = [] - - # Check for embedding models - if "embed" in hf_name.lower(): - capabilities.append("embeddings") - else: - capabilities.append("text-generation") - - # Check for chat/instruct models - if any(keyword in hf_name.lower() for keyword in ["instruct", "chat"]): - capabilities.append("chat") - - return capabilities - - -def detect_model_type(hf_name, config_data): - """Detect high-level model type.""" - if "embed" in hf_name.lower(): - return "embedding" - elif any(keyword in hf_name.lower() for keyword in ["instruct", "chat"]): - return "chat" - else: - return "base" - - -def detect_framework(model_path, hf_name: str) -> str: - """Detect model framework similarly to list operation.""" - if "mlx-community" in hf_name: - return "MLX" - # GGUF files - if list(model_path.glob("**/*.gguf")): - return "GGUF" - # PyTorch/safetensors - snapshots_dir = model_path / "snapshots" - if snapshots_dir.exists(): - has_safetensors = any(snapshots_dir.glob("**/*.safetensors")) - has_pytorch_bin = any(snapshots_dir.glob("**/pytorch_model.bin")) - if has_safetensors or has_pytorch_bin: - return "PyTorch" - return "Unknown" - - -def get_total_size_bytes(model_path): - """Calculate total model size in bytes.""" - if not model_path.exists(): - return 0 - - total_size = 0 - for file_path in model_path.rglob("*"): - if file_path.is_file(): - total_size += file_path.stat().st_size - return total_size +def _is_40_hex(s: str) -> bool: + return len(s) == 40 and all(c in "0123456789abcdef" for c in s.lower()) def show_model_operation(model_pattern: str, include_files: bool = False, include_config: bool = False) -> Dict[str, Any]: @@ -196,7 +143,7 @@ def show_model_operation(model_pattern: str, include_files: bool = False, includ return result # Get model directory - model_cache_dir = MODEL_CACHE / hf_to_cache_dir(resolved_name) + model_cache_dir = get_current_model_cache() / hf_to_cache_dir(resolved_name) if not model_cache_dir.exists(): result["status"] = "error" result["error"] = { @@ -231,35 +178,17 @@ def show_model_operation(model_pattern: str, include_files: bool = False, includ if not model_path: model_path = model_cache_dir - # Get health status - healthy, health_reason = is_model_healthy(resolved_name) - - # Calculate size in bytes - total_size_bytes = get_total_size_bytes(model_path) - - # Get config data for metadata - config_data = get_config_content(model_path) - + # Build unified model object + model_obj = build_model_object(resolved_name, model_cache_dir, model_path) + # Build response data - data = { - "model": { - "name": resolved_name, - "hash": commit_hash, - "size_bytes": total_size_bytes, - "last_modified": datetime.fromtimestamp(model_path.stat().st_mtime).strftime("%Y-%m-%dT%H:%M:%SZ"), - "framework": detect_framework(model_cache_dir, resolved_name), - "model_type": detect_model_type(resolved_name, config_data), - "capabilities": detect_model_capabilities(resolved_name, config_data), - "health": "healthy" if healthy else "unhealthy", - "cached": True, - } - } + data = {"model": model_obj} if include_files: data["files"] = get_model_files(model_path) data["metadata"] = None elif include_config: - data["config"] = config_data + data["config"] = get_config_content(model_path) data["metadata"] = None else: data["metadata"] = extract_model_metadata(model_path) diff --git a/mlxk2/output/human.py b/mlxk2/output/human.py index 0818b0f..3c019f6 100644 --- a/mlxk2/output/human.py +++ b/mlxk2/output/human.py @@ -82,11 +82,26 @@ def render_list(data: Dict[str, Any], show_health: bool, show_all: bool, verbose if show_health: headers.append("Health") - # Human filter: by default only show MLX framework; with --all show everything + # Human filter: + # - --all: show everything + # - default: show only MLX chat models (safer for run/server selection) + # - --verbose (without --all): show all MLX models (chat + base) filtered: List[Dict[str, Any]] = [] for m in models: - if show_all or str(m.get("framework", "")).upper() == "MLX": + fw = str(m.get("framework", "")).upper() + typ = str(m.get("model_type", "")).lower() + if show_all: filtered.append(m) + else: + if fw != "MLX": + continue + if verbose: + # In verbose mode, show all MLX models + filtered.append(m) + else: + # Default compact mode: only MLX chat + if typ == "chat": + filtered.append(m) rows: List[List[str]] = [] for m in filtered: diff --git a/pyproject-mlxk-json.toml b/pyproject-mlxk-json.toml index c24d4b9..5606b69 100644 --- a/pyproject-mlxk-json.toml +++ b/pyproject-mlxk-json.toml @@ -8,7 +8,7 @@ version = "2.0.0-alpha" description = "MLX-Knife 2.0 - JSON-first model management for automation" readme = "README.md" requires-python = ">=3.9" -license = {text = "MIT"} +license = {text = "Apache-2.0"} authors = [ {name = "The BROKE team", email = "broke@gmx.eu"}, ] @@ -24,6 +24,7 @@ classifiers = [ "Programming Language :: Python :: 3.13", "Operating System :: MacOS", "Environment :: Console", + "License :: OSI Approved :: Apache Software License", ] dependencies = [ "huggingface-hub>=0.34.0", @@ -43,4 +44,10 @@ include = ["mlxk2*"] exclude = ["tests*", "tests_2.0*"] [tool.setuptools.dynamic] -version = {attr = "mlxk2.__version__"} \ No newline at end of file +version = {attr = "mlxk2.__version__"} + +[tool.setuptools] +license-files = [ + "LICENSE", + "mlxk2/NOTICE", +] diff --git a/pyproject.toml b/pyproject.toml index ec3be40..f459447 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "MLX-Knife 2.0 - JSON-first model management for automation" readme = "README.md" requires-python = ">=3.9" -license = {text = "MIT"} +license = {text = "Apache-2.0"} authors = [ {name = "The BROKE team", email = "broke@gmx.eu"}, ] @@ -24,6 +24,7 @@ classifiers = [ "Programming Language :: Python :: 3.13", "Operating System :: MacOS", "Environment :: Console", + "License :: OSI Approved :: Apache Software License", ] dependencies = [ "huggingface-hub>=0.34.0", @@ -50,3 +51,9 @@ test = [ "pytest>=7", "jsonschema>=4.20", ] + +[tool.setuptools] +license-files = [ + "LICENSE", + "mlxk2/NOTICE", +] diff --git a/pytest.ini b/pytest.ini index a4f489f..93b8bf4 100644 --- a/pytest.ini +++ b/pytest.ini @@ -7,3 +7,4 @@ markers = spec: JSON API contract tests (current spec only) wet: Opt-in live tests against Hugging Face (require env) live_push: Alias for wet; push live tests (require env) + live_list: Alias for wet; list human live tests (require env) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index c0cff0a..0000000 --- a/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""MLX Knife Test Suite""" \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 6fb8024..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,196 +0,0 @@ -""" -Pytest configuration and shared fixtures for MLX Knife tests. -""" -import os -import tempfile -import shutil -import pytest -import subprocess -import signal -import time -from pathlib import Path -from typing import Generator, List -import psutil - - -@pytest.fixture -def temp_cache_dir() -> Generator[Path, None, None]: - """Create a temporary cache directory for isolated testing.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_path = Path(temp_dir) / "test_cache" - cache_path.mkdir() - - # Create hub subdirectory (required by HF_HOME/hub fix) - hub_path = cache_path / "hub" - hub_path.mkdir() - - # Set HF_HOME to our temp directory - old_hf_home = os.environ.get("HF_HOME") - os.environ["HF_HOME"] = str(cache_path) - - try: - yield cache_path - finally: - # Restore original HF_HOME - if old_hf_home: - os.environ["HF_HOME"] = old_hf_home - elif "HF_HOME" in os.environ: - del os.environ["HF_HOME"] - - -@pytest.fixture(scope="class") -def class_temp_cache_dir() -> Generator[Path, None, None]: - """Create a temporary cache directory for class-level testing (setup_class/teardown_class).""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_path = Path(temp_dir) / "test_cache" - cache_path.mkdir() - - # Create hub subdirectory (required by HF_HOME/hub fix) - hub_path = cache_path / "hub" - hub_path.mkdir() - - # Set HF_HOME to our temp directory - old_hf_home = os.environ.get("HF_HOME") - os.environ["HF_HOME"] = str(cache_path) - - try: - yield cache_path - finally: - # Restore original HF_HOME - if old_hf_home: - os.environ["HF_HOME"] = old_hf_home - elif "HF_HOME" in os.environ: - del os.environ["HF_HOME"] - - -@pytest.fixture -def patch_model_cache(): - """Utility fixture to temporarily patch MODEL_CACHE to isolated directory.""" - from contextlib import contextmanager - - @contextmanager - def _patch_cache(cache_path: Path): - from mlx_knife import cache_utils - original_cache = cache_utils.MODEL_CACHE - cache_utils.MODEL_CACHE = cache_path - try: - yield cache_path - finally: - cache_utils.MODEL_CACHE = original_cache - - return _patch_cache - - -@pytest.fixture -def mlx_knife_process(): - """Factory fixture to create and manage mlx_knife subprocess.""" - processes: List[subprocess.Popen] = [] - - def _create_process(args: List[str], **kwargs) -> subprocess.Popen: - """Create a new mlx_knife process and track it.""" - full_args = ["python", "-m", "mlx_knife.cli"] + args - proc = subprocess.Popen( - full_args, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - **kwargs - ) - processes.append(proc) - return proc - - yield _create_process - - # Cleanup: Kill all created processes - for proc in processes: - if proc.poll() is None: # Process still running - try: - proc.terminate() - proc.wait(timeout=5) - except subprocess.TimeoutExpired: - proc.kill() - proc.wait() - - -@pytest.fixture -def process_monitor(): - """Monitor processes for zombie detection.""" - def _get_process_tree(pid: int) -> List[psutil.Process]: - """Get all child processes of a given PID.""" - try: - parent = psutil.Process(pid) - return parent.children(recursive=True) - except psutil.NoSuchProcess: - return [] - - def _wait_for_process_cleanup(pid: int, timeout: float = 5.0) -> bool: - """Wait for all child processes to terminate.""" - start_time = time.time() - while time.time() - start_time < timeout: - children = _get_process_tree(pid) - if not children: - return True - time.sleep(0.1) - return False - - return { - "get_process_tree": _get_process_tree, - "wait_for_cleanup": _wait_for_process_cleanup - } - - -@pytest.fixture -def mock_model_cache(temp_cache_dir): - """Create mock model cache structures for testing.""" - def _create_mock_model( - model_name: str, - healthy: bool = True, - corruption_type: str = None - ) -> Path: - """Create a mock model in the cache directory.""" - # Convert model name to cache directory format - cache_name = model_name.replace("/", "--") - # Create models in hub subdirectory (HF_HOME/hub fix) - hub_dir = temp_cache_dir / "hub" - model_dir = hub_dir / f"models--{cache_name}" / "snapshots" / "main" - model_dir.mkdir(parents=True, exist_ok=True) - - if healthy and not corruption_type: - # Create healthy model files - (model_dir / "config.json").write_text('{"model_type": "test"}') - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - (model_dir / "model.safetensors").write_bytes(b"fake_model_data" * 100) - elif corruption_type: - _create_corrupted_model(model_dir, corruption_type) - - return model_dir - - def _create_corrupted_model(model_dir: Path, corruption_type: str): - """Create various types of corrupted models.""" - if corruption_type == "missing_snapshot": - # Remove snapshots directory - shutil.rmtree(model_dir.parent.parent) - elif corruption_type == "missing_config": - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - (model_dir / "model.safetensors").write_bytes(b"fake_model_data") - # config.json is missing - elif corruption_type == "lfs_pointer": - (model_dir / "config.json").write_text('{"model_type": "test"}') - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - # Create LFS pointer file instead of actual data - (model_dir / "model.safetensors").write_text( - "version https://git-lfs.github.com/spec/v1\n" - "oid sha256:abc123\n" - "size 1000000\n" - ) - elif corruption_type == "truncated_safetensors": - (model_dir / "config.json").write_text('{"model_type": "test"}') - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - # Create truncated/corrupted safetensors - (model_dir / "model.safetensors").write_bytes(b"corrupted") - elif corruption_type == "missing_tokenizer": - (model_dir / "config.json").write_text('{"model_type": "test"}') - (model_dir / "model.safetensors").write_bytes(b"fake_model_data") - # tokenizer.json is missing - - return _create_mock_model \ No newline at end of file diff --git a/tests/integration/test_core_functionality.py b/tests/integration/test_core_functionality.py deleted file mode 100644 index e2d6142..0000000 --- a/tests/integration/test_core_functionality.py +++ /dev/null @@ -1,319 +0,0 @@ -""" -High Priority Tests: Core Functionality - -Tests ensure primary features work correctly: -- Model execution (run command, streaming, token decoding, stop tokens) -- Basic operations (list, show, pull, rm) -- Chat template application -""" -import pytest -import subprocess -import json -import time -from pathlib import Path -from unittest.mock import patch, MagicMock - - -@pytest.mark.timeout(30) -class TestBasicOperations: - """Test core CLI operations.""" - - def test_list_command_empty_cache(self, mlx_knife_process, temp_cache_dir): - """List command should handle empty cache gracefully.""" - proc = mlx_knife_process(["list"]) - stdout, stderr = proc.communicate(timeout=10) - - # Should complete successfully - assert proc.returncode == 0, f"List failed on empty cache: {stderr}" - - # Should produce some output (even if empty list) - assert len(stdout) >= 0 - # Common outputs for empty cache: "No models found" or empty list - - def test_list_command_with_models(self, mlx_knife_process, mock_model_cache): - """List command should display available models.""" - # Create some mock models - mock_model_cache("test-model-1", healthy=True) - mock_model_cache("test-model-2", healthy=True) - - proc = mlx_knife_process(["list"]) - stdout, stderr = proc.communicate(timeout=10) - - assert proc.returncode == 0, f"List failed: {stderr}" - assert len(stdout) > 0, "List produced no output with models present" - - # Should contain reference to models (exact format depends on implementation) - output_lower = stdout.lower() - assert "test" in output_lower or "model" in output_lower or len(stdout.split('\n')) > 1 - - def test_show_command_existing_model(self, mlx_knife_process, mock_model_cache): - """Show command should display model details.""" - model_dir = mock_model_cache("test-model", healthy=True) - - # Try different possible model name formats - model_names_to_try = ["test-model", "test/model", "models--test-model"] - - success = False - for model_name in model_names_to_try: - proc = mlx_knife_process(["show", model_name]) - stdout, stderr = proc.communicate(timeout=10) - - if proc.returncode == 0 and len(stdout) > 0: - success = True - break - - # At least one format should work, or command should handle gracefully - # The key is that it doesn't crash or hang - assert success or all( - proc.returncode is not None for proc in [ - mlx_knife_process(["show", name]) - for name in model_names_to_try - ] - ), "Show command hung or crashed" - - def test_show_command_nonexistent_model(self, mlx_knife_process, temp_cache_dir): - """Show command should handle nonexistent models gracefully.""" - proc = mlx_knife_process(["show", "nonexistent-model"]) - stdout, stderr = proc.communicate(timeout=10) - - # Should complete (likely with error code) - assert proc.returncode is not None, "Show command hung" - - # Should produce some error message - output = stdout + stderr - assert len(output) > 0, "No error message for nonexistent model" - - def test_rm_command_safety(self, mlx_knife_process, temp_cache_dir): - """Remove command should handle nonexistent models safely.""" - proc = mlx_knife_process(["rm", "nonexistent-model"]) - stdout, stderr = proc.communicate(timeout=10) - - # Should complete (may succeed or fail gracefully) - assert proc.returncode is not None, "Remove command hung" - - # Should not crash - # Exact behavior depends on implementation - - def test_rm_command_corrupted_empty_snapshots(self, mlx_knife_process, temp_cache_dir): - """Remove command should handle corrupted models with empty snapshots directory.""" - from mlx_knife.cache_utils import hf_to_cache_dir - - # Create a corrupted model structure (directory exists but snapshots is empty) - test_model = "test-org/corrupted-empty-model" - # Create in hub subdirectory (new cache structure) - hub_dir = temp_cache_dir / "hub" - cache_dir = hub_dir / hf_to_cache_dir(test_model) - cache_dir.mkdir(parents=True, exist_ok=True) - (cache_dir / "snapshots").mkdir(exist_ok=True) - (cache_dir / "blobs").mkdir(exist_ok=True) - (cache_dir / "refs").mkdir(exist_ok=True) - - try: - # This should NOT fail silently - should either provide error message or handle deletion - # Use --force to avoid hanging on input prompts in test environment - proc = mlx_knife_process(["rm", test_model, "--force"]) - stdout, stderr = proc.communicate(timeout=10) - - # Should complete (not hang) - assert proc.returncode is not None, "Remove command hung on corrupted model" - - # Should produce SOME output (not silent failure) - output = (stdout + stderr).strip() - assert len(output) > 0, "Remove command failed silently on corrupted model - no output produced" - - # The behavior should be explicit: either error message or deletion prompt/confirmation - output_lower = output.lower() - has_error = "error" in output_lower or "not found" in output_lower - has_prompt = "delete" in output_lower or "remove" in output_lower - - assert has_error or has_prompt, f"Remove command should provide clear feedback, got: {output}" - - finally: - # Cleanup - remove the test corrupted model structure - import shutil - if cache_dir.exists(): - shutil.rmtree(cache_dir) - - -@pytest.mark.timeout(60) -class TestModelExecution: - """Test model loading and execution functionality.""" - - def test_run_command_basic_prompt(self, mlx_knife_process): - """Test basic model execution with prompt using real MLX model.""" - # Uses Phi-3-mini-4k-instruct-4bit (assumes already pulled and healthy) - test_model = "Phi-3-mini-4k-instruct-4bit" - test_prompt = "Say hello." - - proc = mlx_knife_process(["run", test_model, test_prompt, "--max-tokens", "20"]) - stdout, stderr = proc.communicate(timeout=60) - - # Test MLX Knife functionality, not model quality - assert proc.returncode == 0, f"MLX Knife execution failed: {stderr}" - assert len(stdout.strip()) > 0, "MLX Knife produced no output - model loading/generation failed" - assert len(stdout.strip()) < 1000, f"MLX Knife did not respect max-tokens limit: {len(stdout)} chars" - - # Basic sanity check: output should be reasonable text (not binary garbage) - # Allow common whitespace characters (newlines, tabs, spaces) - clean_output = stdout.replace('\n', '').replace('\t', '').replace('\r', '') - assert clean_output.isprintable(), f"MLX Knife produced non-printable output: {repr(stdout)}" - - def test_run_command_invalid_model(self, mlx_knife_process, temp_cache_dir): - """Run command should handle invalid models gracefully.""" - proc = mlx_knife_process(["run", "nonexistent-model", "test prompt"]) - stdout, stderr = proc.communicate(timeout=15) - - # Should fail gracefully, not hang - assert proc.returncode is not None, "Run command hung on invalid model" - assert proc.returncode != 0, "Run should fail on nonexistent model" - - # Should produce error message - output = stdout + stderr - assert len(output) > 0, "No error message for invalid model" - - def test_streaming_token_generation(self, mlx_knife_process): - """Test streaming token output with real MLX model.""" - test_model = "Phi-3-mini-4k-instruct-4bit" - test_prompt = "Write the word 'test' three times." - - proc = mlx_knife_process(["run", test_model, test_prompt, "--max-tokens", "30"]) - stdout, stderr = proc.communicate(timeout=45) - - # Test MLX Knife streaming functionality, not model accuracy - assert proc.returncode == 0, f"MLX Knife streaming failed: {stderr}" - assert len(stdout.strip()) > 0, "MLX Knife streaming produced no output" - assert len(stdout.strip()) < 2000, f"MLX Knife streaming did not respect token limits: {len(stdout)} chars" - - # Verify streaming worked by checking output is reasonable text - # Allow common whitespace characters (newlines, tabs, spaces) - clean_output = stdout.replace('\n', '').replace('\t', '').replace('\r', '') - assert clean_output.isprintable(), f"MLX Knife streaming produced non-printable output: {repr(stdout)}" - - - -@pytest.mark.timeout(120) -class TestPullOperation: - """Test model downloading functionality.""" - - def test_pull_command_invalid_model(self, mlx_knife_process, temp_cache_dir): - """Pull command should handle invalid model names gracefully.""" - proc = mlx_knife_process(["pull", "definitely-not-a-real-model-12345"]) - stdout, stderr = proc.communicate(timeout=30) - - # Should fail, not hang - assert proc.returncode is not None, "Pull command hung" - assert proc.returncode != 0, "Pull should fail on invalid model" - - # Should produce error message - output = stdout + stderr - assert len(output) > 0, "No error message for invalid model" - - def test_pull_command_network_timeout_handling(self, mlx_knife_process, temp_cache_dir, patch_model_cache): - """Pull command should handle network issues gracefully - uses isolated cache.""" - # Use Phi-3-mini for realistic timeout testing, but in ISOLATED cache - with patch_model_cache(temp_cache_dir / "hub"): - proc = mlx_knife_process(["pull", "mlx-community/Phi-3-mini-4k-instruct-4bit", "--no-progress"]) - - # Give it limited time to start, then interrupt - time.sleep(5) - - if proc.poll() is None: # Still running - proc.send_signal(subprocess.signal.SIGINT) - try: - stdout, stderr = proc.communicate(timeout=15) - except subprocess.TimeoutExpired: - proc.kill() - stdout, stderr = proc.communicate() - else: - stdout, stderr = proc.communicate() - - # Key test: should not hang indefinitely - assert proc.returncode is not None, "Pull command did not terminate" - - # Should handle interruption gracefully - output = stdout + stderr - assert len(output) >= 0 # Some output expected - - print("✓ Timeout test completed - any broken Phi-3-mini in isolated cache will be auto-cleaned") - - -@pytest.mark.timeout(30) -class TestCommandLineInterface: - """Test CLI argument parsing and help functionality.""" - - def test_help_command(self, mlx_knife_process): - """Help command should display usage information.""" - proc = mlx_knife_process(["--help"]) - stdout, stderr = proc.communicate(timeout=10) - - # Should succeed - assert proc.returncode == 0, f"Help command failed: {stderr}" - - # Should produce help output - assert len(stdout) > 0, "Help produced no output" - - # Should contain basic command information - help_text = stdout.lower() - assert any(cmd in help_text for cmd in ["list", "pull", "run", "health"]), \ - "Help missing core commands" - - def test_version_command(self, mlx_knife_process): - """Version command should display version information.""" - # Try common version flags - version_flags = ["--version", "-v"] - - success = False - for flag in version_flags: - try: - proc = mlx_knife_process([flag]) - stdout, stderr = proc.communicate(timeout=10) - - if proc.returncode == 0 and len(stdout) > 0: - success = True - # Should contain version number - assert any(char.isdigit() for char in stdout), \ - "Version output contains no digits" - break - except: - continue - - # At least one version flag should work, or command should handle gracefully - if not success: - # Test that invalid flags are handled - proc = mlx_knife_process(["--invalid-flag"]) - stdout, stderr = proc.communicate(timeout=10) - assert proc.returncode is not None, "Invalid flag handling hung" - - def test_invalid_command_handling(self, mlx_knife_process): - """Invalid commands should be handled gracefully.""" - proc = mlx_knife_process(["invalid-command-xyz"]) - stdout, stderr = proc.communicate(timeout=10) - - # Should fail but not hang - assert proc.returncode is not None, "Invalid command hung" - assert proc.returncode != 0, "Invalid command should not succeed" - - # Should produce error message - output = stdout + stderr - assert len(output) > 0, "No error message for invalid command" - - def test_missing_arguments_handling(self, mlx_knife_process): - """Commands missing required arguments should fail gracefully.""" - # Test commands that require arguments - commands_needing_args = [ - ["run"], # needs model and prompt - ["show"], # needs model name - ["pull"], # needs model name - ] - - for cmd in commands_needing_args: - proc = mlx_knife_process(cmd) - stdout, stderr = proc.communicate(timeout=10) - - # Should fail gracefully - assert proc.returncode is not None, f"Command {cmd} hung" - assert proc.returncode != 0, f"Command {cmd} should fail without required args" - - # Should produce helpful error - output = stdout + stderr - assert len(output) > 0, f"No error message for {cmd} without args" \ No newline at end of file diff --git a/tests/integration/test_end_token_issue.py b/tests/integration/test_end_token_issue.py deleted file mode 100644 index 112959a..0000000 --- a/tests/integration/test_end_token_issue.py +++ /dev/null @@ -1,534 +0,0 @@ -""" -Test for End-Token Issue: Streaming vs Non-Streaming Consistency - -This test ensures that End-Tokens are handled consistently across different -models and streaming modes using actual token metrics instead of word estimates. -""" - -import logging -import signal -import subprocess -import time -from typing import Dict, List, Tuple, Any -import json - -import psutil -import pytest -import requests - -logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') -logger = logging.getLogger(__name__) - -# Realistic RAM requirements for 4-bit quantized models (in GB) -MODEL_RAM_REQUIREMENTS = { - "0.5B": 1, "1B": 2, "3B": 4, "4B": 5, - "7B": 8, "8x7B": 16, "24B": 20, "30B": 24, - "70B": 40, "480B": 180 -} - -# Model-specific End-Tokens to check for (comprehensive list) -MODEL_END_TOKENS = { - "llama": ["", "<|end_of_text|>", "<|eot_id|>"], # Llama-2/3.x tokens - "mistral": ["", "<|endoftext|>"], # Mistral variants - "qwen": ["<|im_end|>", "<|endoftext|>", "<|end|>", ""], # Qwen variants - "phi": ["<|endoftext|>", "<|end|>", ""], # Phi-3 variants - "mixtral": ["", "<|endoftext|>"], # Mixtral (Mistral-based) - "default": [ # Comprehensive catch-all list - "", "<|im_end|>", "<|endoftext|>", "<|end_of_text|>", - "<|eot_id|>", "<|end|>", "", "", "", "", - "<|assistant|>", "<|user|>", "<|system|>" - ] -} - -SERVER_BASE_URL = "http://localhost:8000" -SERVER_PORT = 8000 - - -def extract_model_size(model_name: str) -> str: - """Extract model size from model name.""" - import re - - # Match patterns like "30B", "8x7B", "480B", "0.5B", "3.2B" - size_patterns = [ - r'(\d+(?:\.\d+)?B)', # Standard: 30B, 3.2B, 0.5B - r'(\d+x\d+B)', # MoE: 8x7B - r'(480B)', # Special: 480B - r'Phi-3-mini', # Map to 4B - r'small', # Map to 7B (lowercase) - r'Small', # Map to 7B (capitalized) - ] - - for pattern in size_patterns: - match = re.search(pattern, model_name, re.IGNORECASE) - if match: - size = match.group(1) - if 'Phi-3-mini' in size: - return '4B' - elif 'small' in size.lower(): - return '7B' - return size - - return '7B' # Default fallback - - -def get_model_family(model_name: str) -> str: - """Determine model family for End-Token selection.""" - model_lower = model_name.lower() - - if 'llama' in model_lower: - return 'llama' - elif 'mistral' in model_lower and 'mixtral' not in model_lower: - return 'mistral' - elif 'qwen' in model_lower: - return 'qwen' - elif 'phi' in model_lower: - return 'phi' - elif 'mixtral' in model_lower: - return 'mixtral' - else: - return 'default' - - -def get_available_ram_gb() -> int: - """Get available system RAM in GB.""" - return psutil.virtual_memory().available // (1024**3) - - -class MLXKnifeServerManager: - """Context manager for MLX Knife server lifecycle.""" - - def __init__(self): - self.process = None - - def __enter__(self): - self.start_server() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop_server() - - def start_server(self): - """Start MLX Knife server.""" - logger.info("Starting MLX Knife server...") - self.process = subprocess.Popen( - ["mlxk", "server", "--host", "127.0.0.1", "--port", str(SERVER_PORT)], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN) - ) - - # Wait for server to be ready - for attempt in range(30): - try: - response = requests.get(f"{SERVER_BASE_URL}/health", timeout=2) - if response.status_code == 200: - logger.info("Server is ready") - return - except: - pass - time.sleep(1) - - raise RuntimeError("Server failed to start within 30 seconds") - - def stop_server(self): - """Stop MLX Knife server with proper cleanup.""" - if self.process: - logger.info("Stopping server...") - # Graceful shutdown attempt - self.process.terminate() - try: - self.process.wait(timeout=10) - logger.info("Server stopped gracefully") - except subprocess.TimeoutExpired: - logger.warning("Server did not stop gracefully, force killing...") - self.process.kill() - self.process.wait() - logger.info("Server force killed") - - # Wait a bit for port cleanup - time.sleep(2) - - # Verify port is actually free - for attempt in range(5): - try: - response = requests.get(f"{SERVER_BASE_URL}/health", timeout=1) - if attempt == 4: - logger.warning("Port may still be occupied after server shutdown") - time.sleep(1) - except requests.exceptions.RequestException: - # Good - server is really down - logger.info("Port confirmed free") - break - - -def get_available_models() -> List[str]: - """Get list of available models from server.""" - try: - response = requests.get(f"{SERVER_BASE_URL}/v1/models", timeout=10) - if response.status_code == 200: - data = response.json() - return [model["id"] for model in data.get("data", [])] - except Exception as e: - logger.warning(f"Failed to get models: {e}") - return [] - - -def get_safe_models_for_system() -> List[Tuple[str, str, int]]: - """Get models that can safely run on current system.""" - models = get_available_models() - available_ram = get_available_ram_gb() - safe_models = [] - - for model in models: - size_str = extract_model_size(model) - ram_needed = MODEL_RAM_REQUIREMENTS.get(size_str, 8) # Default 8GB - - if ram_needed <= available_ram: - safe_models.append((model, size_str, ram_needed)) - - return safe_models - - -def get_model_context_length(model_name: str) -> int: - """Get model's context length from server.""" - try: - response = requests.get(f"{SERVER_BASE_URL}/v1/models", timeout=10) - if response.status_code == 200: - data = response.json() - for model in data.get("data", []): - if model["id"] == model_name: - return model.get("context_length", 4096) - except Exception: - pass - return 4096 # Default fallback - - -def get_model_aware_token_targets(model_name: str, model_size: str) -> Dict[str, int]: - """Get realistic token targets based on actual model capabilities.""" - context_length = get_model_context_length(model_name) - - # Calculate reasonable target based on model size + context - if model_size in ["1B", "3B"]: - target_tokens = min(512, context_length // 8) - elif model_size in ["4B", "7B"]: - target_tokens = min(1024, context_length // 6) - elif model_size in ["24B", "30B", "70B"]: - target_tokens = min(2048, context_length // 4) - else: - target_tokens = min(800, context_length // 6) - - return { - "target_tokens": target_tokens, - "min_tokens": target_tokens // 3, # Allow 33% variance - "context_length": context_length - } - - -def create_adaptive_trilogy_prompt(model_size: str, target_tokens: int) -> str: - """Create trilogy prompt adapted to model capabilities.""" - - base_plot = '''Here is the outline for fantasy trilogy "EMBERS OF THE FORGOTTEN": - -**MAIN CHARACTERS:** -1. Kaelen Veyra - The Exiled Flame Herald (32, war poet, controls Soulfire) -2. Sylra D'Tharn - The Shadow Warrior (28, assassin, uses Emotionweave) -3. Lord Morvath - The Unforgotten King (45, tragic villain with Grief-Crown) - -**TRILOGY STRUCTURE:** -- Book I: "Embers of the Forgotten" - The flame that remembers -- Book II: "The Lovers' Crucible" - The fire that doesn't burn -- Book III: "The Fire That Binds" - The flame that connects - -**THEMES:** Love as power not weakness, memory as healing, emotions as connection''' - - if model_size in ["1B", "3B"]: - task = f'''**YOUR TASK:** Write a 500-word opening scene of Book I featuring Kaelen's exile. -- Focus on Kaelen's emotional state after Lirien's death -- Use poetic, mythic language -- Target approximately {target_tokens} tokens -- End with him seeing Veyra (Valley of Faces) in the distance''' - - elif model_size in ["4B", "7B"]: - task = f'''**YOUR TASK:** Write the opening chapter of Book I: "The Poet Who Burned" -- Focus on Kaelen's exile from Celestine after Lirien's execution -- Include his emotional journey and Soulfire powers -- Use poetic, mythic language with deep inner rhythm -- Target approximately {target_tokens} tokens (1000-1500 words) -- End with his arrival at Veyra (Valley of Faces)''' - - else: # 24B, 30B, 70B - task = f'''**YOUR TASK:** Write the complete first chapter of Book I: "The Poet Who Burned" -- Focus on Kaelen's exile from Celestine after his beloved Lirien's execution -- Include his arrival at Veyra (Valley of Faces) with 30 lost masks -- Show his Soulfire powers and deep emotional development -- Use poetic, mythic language with deep inner rhythm -- Target approximately {target_tokens} tokens (2000+ words) -- Include dialogue and rich character development -- End with the mysterious mask whispering: "You were here - a thousand years ago"''' - - return f"{base_plot}\n\n{task}\n\nWrite the complete chapter now." - - -def make_chat_request(model_name: str, prompt: str, stream: bool = False, timeout: int = 120) -> str: - """Make chat completion request to server.""" - payload = { - "model": model_name, - "messages": [{"role": "user", "content": prompt}], - "stream": stream, - "temperature": 0.7 - } - - response = requests.post( - f"{SERVER_BASE_URL}/v1/chat/completions", - json=payload, - timeout=timeout, - stream=stream - ) - - if not response.ok: - raise RuntimeError(f"Request failed: {response.status_code} - {response.text}") - - if stream: - # Handle streaming response - content = "" - for line in response.iter_lines(decode_unicode=True): - if line.startswith("data: "): - data_str = line[6:] - if data_str.strip() == "[DONE]": - break - try: - data = json.loads(data_str) - delta = data.get("choices", [{}])[0].get("delta", {}).get("content", "") - content += delta - except json.JSONDecodeError: - continue - return content - else: - # Handle non-streaming response - data = response.json() - return data.get("choices", [{}])[0].get("message", {}).get("content", "") - - -def contains_end_tokens(text: str, model_name: str) -> List[str]: - """Check if text contains any End-Tokens for the given model.""" - model_family = get_model_family(model_name) - end_tokens = MODEL_END_TOKENS.get(model_family, MODEL_END_TOKENS["default"]) - - found_tokens = [] - for token in end_tokens: - if token in text: - found_tokens.append(token) - - return found_tokens - - -def estimate_token_count(text: str) -> int: - """Rough token count estimation (4 chars per token average).""" - return len(text) // 4 - - -def get_safe_models_lazy(): - """Lazy evaluation for parametrize to avoid import-time server calls.""" - try: - return get_safe_models_for_system() - except: - return [("test-model", "1B", 1)] - - -def pytest_generate_tests(metafunc): - """Dynamic test parametrization to avoid import-time server calls.""" - if "model_name" in metafunc.fixturenames: - try: - with MLXKnifeServerManager() as server: - models = get_safe_models_for_system() - metafunc.parametrize("model_name,size_str,ram_needed", models) - except Exception as e: - pytest.skip(f"Cannot set up server for testing: {e}") - - -@pytest.mark.server -@pytest.mark.timeout(300) # 5 minute timeout for large models -def test_non_streaming_end_tokens(model_name, size_str, ram_needed): - """ - Test Issue #20: Non-streaming mode should show End-Tokens (EXPECTED TO FAIL). - - This test validates that non-streaming responses contain visible End-Tokens, - proving the server-side filtering bug in generate_batch(). - - Expected result: FAIL (End-Tokens visible) - this confirms Issue #20. - """ - logger.info(f"🔍 Testing NON-STREAMING End-Tokens with {model_name} ({size_str}, {ram_needed}GB RAM)") - - with MLXKnifeServerManager() as server: - # Get model-specific token targets - token_specs = get_model_aware_token_targets(model_name, size_str) - logger.info(f"Token targets: {token_specs}") - - # Create adaptive prompt (no max_tokens - let model use natural stopping) - prompt = create_adaptive_trilogy_prompt(size_str, token_specs["target_tokens"]) - - logger.info("🚫 Testing NON-STREAMING mode (should show End-Tokens)...") - - response_content = make_chat_request(model_name, prompt, stream=False, timeout=300) - - # Basic validation - assert response_content.strip(), "Non-streaming returned empty response" - - # Token count validation - estimated_tokens = estimate_token_count(response_content) - logger.info(f"Non-streaming response: ~{estimated_tokens} tokens") - logger.info(f"Response ends with: '{response_content[-100:]}'" if len(response_content) > 100 else f"Full response end: '{response_content}'") - - # Should generate reasonable amount - min_expected = token_specs["min_tokens"] - assert estimated_tokens >= min_expected, \ - f"Non-streaming generated too few tokens: {estimated_tokens} < {min_expected}" - - # Issue #20 Check: Non-streaming SHOULD contain End-Tokens (this is the bug) - found_end_tokens = contains_end_tokens(response_content, model_name) - - if found_end_tokens: - logger.error(f"❌ CONFIRMED Issue #20: Non-streaming contains End-Tokens: {found_end_tokens}") - logger.error(f"Raw response end: {repr(response_content[-50:])}") - # This SHOULD fail - it confirms Issue #20 - assert False, f"Issue #20 CONFIRMED: Non-streaming shows End-Tokens {found_end_tokens}" - else: - logger.warning(f"⚠️ UNEXPECTED: Non-streaming clean (no End-Tokens found)") - logger.info(f"✅ Non-streaming mode unexpectedly passed (no Issue #20 detected)") - - -@pytest.mark.server -@pytest.mark.timeout(300) # 5 minute timeout for large models -def test_streaming_end_tokens(model_name, size_str, ram_needed): - """ - Test Issue #20: Streaming mode should filter End-Tokens (EXPECTED TO PASS). - - This test validates that streaming responses properly filter End-Tokens, - proving the streaming pipeline works correctly. - - Expected result: PASS (End-Tokens filtered) - this shows streaming works correctly. - """ - logger.info(f"🔍 Testing STREAMING End-Tokens with {model_name} ({size_str}, {ram_needed}GB RAM)") - - with MLXKnifeServerManager() as server: - # Get model-specific token targets - token_specs = get_model_aware_token_targets(model_name, size_str) - logger.info(f"Token targets: {token_specs}") - - # Create adaptive prompt (no max_tokens - let model use natural stopping) - prompt = create_adaptive_trilogy_prompt(size_str, token_specs["target_tokens"]) - - logger.info("✅ Testing STREAMING mode (should filter End-Tokens)...") - - response_content = make_chat_request(model_name, prompt, stream=True, timeout=300) - - # Basic validation - assert response_content.strip(), "Streaming returned empty response" - - # Token count validation - estimated_tokens = estimate_token_count(response_content) - logger.info(f"Streaming response: ~{estimated_tokens} tokens") - logger.info(f"Response ends with: '{response_content[-100:]}'" if len(response_content) > 100 else f"Full response end: '{response_content}'") - - # Should generate reasonable amount - min_expected = token_specs["min_tokens"] - assert estimated_tokens >= min_expected, \ - f"Streaming generated too few tokens: {estimated_tokens} < {min_expected}" - - # Issue #20 Check: Streaming should NOT contain End-Tokens (correct behavior) - found_end_tokens = contains_end_tokens(response_content, model_name) - - if found_end_tokens: - logger.error(f"❌ UNEXPECTED: Streaming contains End-Tokens: {found_end_tokens}") - logger.error(f"Raw response end: {repr(response_content[-50:])}") - assert False, f"Streaming unexpectedly shows End-Tokens {found_end_tokens}" - else: - logger.info(f"✅ Streaming mode correctly filtered End-Tokens") - - -@pytest.mark.server -@pytest.mark.timeout(600) # Longer timeout for comparison test -def test_end_token_consistency_comparison(model_name, size_str, ram_needed): - """ - Test Issue #20: Direct comparison of streaming vs non-streaming End-Token handling. - - This test runs both modes and compares their End-Token behavior to document - the exact differences for Issue #20 analysis. - - Expected pattern: - - Non-streaming: Contains End-Tokens (Issue #20 bug) - - Streaming: Clean responses (correct behavior) - """ - logger.info(f"🔍 COMPARISON TEST: {model_name} ({size_str}, {ram_needed}GB RAM)") - logger.info("="*80) - - with MLXKnifeServerManager() as server: - # Get model-specific token targets - token_specs = get_model_aware_token_targets(model_name, size_str) - - # Create adaptive prompt (no max_tokens) - prompt = create_adaptive_trilogy_prompt(size_str, token_specs["target_tokens"]) - - responses = {} - end_token_results = {} - - # Test both modes - for stream_mode in [False, True]: - mode_name = "streaming" if stream_mode else "non-streaming" - logger.info(f"\n📡 Testing {mode_name.upper()} mode...") - - response_content = make_chat_request(model_name, prompt, stream=stream_mode, timeout=300) - responses[stream_mode] = response_content - - # Check End-Tokens - found_end_tokens = contains_end_tokens(response_content, model_name) - end_token_results[stream_mode] = found_end_tokens - - estimated_tokens = estimate_token_count(response_content) - logger.info(f"{mode_name} response: ~{estimated_tokens} tokens") - logger.info(f"{mode_name} ends with: '{response_content[-80:]}'" if len(response_content) > 80 else f"Full: '{response_content}'") - - if found_end_tokens: - logger.error(f"❌ {mode_name} contains End-Tokens: {found_end_tokens}") - else: - logger.info(f"✅ {mode_name} clean (no End-Tokens)") - - # Issue #20 Pattern Analysis - logger.info(f"\n📊 ISSUE #20 ANALYSIS for {model_name}:") - logger.info("="*80) - - non_stream_tokens = end_token_results[False] - stream_tokens = end_token_results[True] - - logger.info(f"Non-streaming End-Tokens: {non_stream_tokens if non_stream_tokens else 'None'}") - logger.info(f"Streaming End-Tokens: {stream_tokens if stream_tokens else 'None'}") - - # Issue #20 pattern detection - if non_stream_tokens and not stream_tokens: - logger.error(f"🎯 ISSUE #20 CONFIRMED!") - logger.error(f" - Non-streaming shows End-Tokens: {non_stream_tokens}") - logger.error(f" - Streaming filters correctly: Clean") - issue_20_detected = True - elif not non_stream_tokens and not stream_tokens: - logger.warning(f"⚠️ Both modes clean - Issue #20 not detected") - issue_20_detected = False - elif non_stream_tokens and stream_tokens: - logger.error(f"🚨 Both modes show End-Tokens - different issue?") - issue_20_detected = False - else: - logger.warning(f"🤔 Unexpected pattern - investigate further") - issue_20_detected = False - - # This test is purely documentary - it doesn't fail, just reports findings - logger.info(f"\n📝 Issue #20 Status: {'CONFIRMED' if issue_20_detected else 'NOT DETECTED'}") - logger.info("="*80) - - -if __name__ == "__main__": - # Quick test run - with MLXKnifeServerManager() as server: - models = get_safe_models_for_system() - print(f"Found {len(models)} safe models for testing:") - for model, size, ram in models: - print(f" {model} ({size}, {ram}GB)") \ No newline at end of file diff --git a/tests/integration/test_health_checks.py b/tests/integration/test_health_checks.py deleted file mode 100644 index f64f682..0000000 --- a/tests/integration/test_health_checks.py +++ /dev/null @@ -1,240 +0,0 @@ -""" -High Priority Tests: Health Check Robustness - -Tests ensure reliable "postmortem" analysis of model integrity: -- Corruption detection (partial downloads, missing files, LFS pointers, etc.) -- Deterministic results (consistent healthy/broken status) -- No false positives or negatives -""" -import pytest -import subprocess -import json -import shutil -from pathlib import Path -from typing import Dict, Any - - -@pytest.mark.timeout(30) -@pytest.mark.usefixtures("temp_cache_dir") -class TestHealthCheckRobustness: - """Test health check reliability for various corruption scenarios.""" - - def test_healthy_model_detection(self, mlx_knife_process, mock_model_cache): - """Verify healthy models are correctly identified.""" - # Create a healthy model - model_dir = mock_model_cache("test-model", healthy=True) - - # Run health check - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=15) - return_code = proc.returncode - - # Should complete successfully - assert return_code == 0, f"Health check failed: {stderr}" - - # Should report healthy status (if any models exist) - # Note: The actual output format depends on implementation - assert "broken" not in stdout.lower() or "0 broken" in stdout.lower() - - def test_missing_snapshot_detection(self, mlx_knife_process, mock_model_cache): - """Health check must detect missing snapshots directory.""" - # Create model with missing snapshots - model_dir = mock_model_cache("test-model", healthy=False, corruption_type="missing_snapshot") - - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=15) - - # Should complete (may return error code if broken models found) - assert proc.returncode is not None - - # Should detect the corruption - either report broken models or handle gracefully - # The key is that it shouldn't crash or hang - assert len(stdout) > 0 or len(stderr) > 0, "Health check produced no output" - - def test_lfs_pointer_detection(self, mlx_knife_process, mock_model_cache): - """Health check must detect LFS pointer files instead of actual weights.""" - model_dir = mock_model_cache("test-model", healthy=False, corruption_type="lfs_pointer") - - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=15) - - # Should handle LFS pointers appropriately - assert proc.returncode is not None - - # Should either detect as broken or handle gracefully - output = stdout + stderr - assert len(output) > 0, "Health check produced no output for LFS pointer" - - def test_missing_config_detection(self, mlx_knife_process, mock_model_cache): - """Health check must detect missing config.json.""" - model_dir = mock_model_cache("test-model", healthy=False, corruption_type="missing_config") - - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=15) - - assert proc.returncode is not None - - # Should detect missing config - output = stdout + stderr - assert len(output) > 0 - - def test_missing_tokenizer_detection(self, mlx_knife_process, mock_model_cache): - """Health check must detect missing tokenizer.json.""" - model_dir = mock_model_cache("test-model", healthy=False, corruption_type="missing_tokenizer") - - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=15) - - assert proc.returncode is not None - output = stdout + stderr - assert len(output) > 0 - - def test_truncated_safetensors_detection(self, mlx_knife_process, mock_model_cache): - """Health check must detect corrupted/truncated safetensors files.""" - model_dir = mock_model_cache("test-model", healthy=False, corruption_type="truncated_safetensors") - - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=15) - - assert proc.returncode is not None - output = stdout + stderr - assert len(output) > 0 - - def test_deterministic_results(self, mlx_knife_process, mock_model_cache): - """Health check results must be consistent across multiple runs.""" - # Create a healthy model - model_dir = mock_model_cache("test-model", healthy=True) - - results = [] - for i in range(3): - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=15) - results.append({ - "return_code": proc.returncode, - "stdout": stdout.strip(), - "stderr": stderr.strip() - }) - - # All runs should have the same return code - return_codes = [r["return_code"] for r in results] - assert all(rc == return_codes[0] for rc in return_codes), f"Inconsistent return codes: {return_codes}" - - # Output should be consistent (allowing for timestamps or minor variations) - stdout_outputs = [r["stdout"] for r in results] - # Basic consistency check - all should have similar length and key content - if stdout_outputs[0]: - for stdout in stdout_outputs[1:]: - # Allow some variation but outputs should be similar - assert abs(len(stdout) - len(stdout_outputs[0])) < 100, "Highly variable output lengths" - - def test_no_false_positives(self, mlx_knife_process, mock_model_cache): - """Healthy model must never be reported as broken.""" - # Create multiple healthy models - for i in range(3): - mock_model_cache(f"healthy-model-{i}", healthy=True) - - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=15) - - # Should succeed - assert proc.returncode == 0, f"Health check failed on healthy models: {stderr}" - - # Should not report broken models (or report 0 broken) - if "broken" in stdout.lower(): - assert "0 broken" in stdout.lower(), f"False positive: {stdout}" - - def test_no_false_negatives_batch(self, mlx_knife_process, mock_model_cache): - """Broken models must be detected reliably.""" - # Create various corrupted models - corruption_types = [ - "missing_config", - "missing_tokenizer", - "lfs_pointer", - "truncated_safetensors" - ] - - for i, corruption in enumerate(corruption_types): - mock_model_cache(f"broken-model-{i}", healthy=False, corruption_type=corruption) - - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=15) - - # Should complete (may have non-zero exit if broken models found) - assert proc.returncode is not None - - # Should produce output indicating broken models or handle them gracefully - output = stdout + stderr - assert len(output) > 0, "No output for batch of broken models" - - def test_mixed_healthy_broken_models(self, mlx_knife_process, mock_model_cache): - """Health check must correctly categorize mixed model states.""" - # Create mix of healthy and broken models - mock_model_cache("healthy-1", healthy=True) - mock_model_cache("broken-1", healthy=False, corruption_type="missing_config") - mock_model_cache("healthy-2", healthy=True) - mock_model_cache("broken-2", healthy=False, corruption_type="lfs_pointer") - - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=15) - - assert proc.returncode is not None - output = stdout + stderr - assert len(output) > 0, "No output for mixed model states" - - # Should handle mixed states appropriately - # The exact format depends on implementation, but should not crash - - -@pytest.mark.timeout(15) -class TestHealthCheckPerformance: - """Test health check performance and reliability.""" - - def test_health_check_timeout_handling(self, mlx_knife_process, temp_cache_dir): - """Health check should complete within reasonable time.""" - # Create several models to check - for i in range(5): - cache_name = f"models--test--model-{i}" - model_dir = temp_cache_dir / cache_name / "snapshots" / "main" - model_dir.mkdir(parents=True, exist_ok=True) - - (model_dir / "config.json").write_text('{"model_type": "test"}') - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - (model_dir / "model.safetensors").write_bytes(b"fake_model_data" * 1000) - - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=30) # Should complete within 30s - - assert proc.returncode is not None, "Health check hung" - - def test_health_check_empty_cache(self, mlx_knife_process, temp_cache_dir): - """Health check should handle empty cache gracefully.""" - # temp_cache_dir is empty - - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=10) - - # Should complete successfully with empty cache - assert proc.returncode == 0, f"Failed on empty cache: {stderr}" - assert len(stdout) >= 0 # Some output is expected (even if just "no models") - - def test_health_check_large_cache(self, mlx_knife_process, temp_cache_dir): - """Health check should handle larger cache sizes.""" - # Create many model directories (simulating large cache) - for i in range(20): - cache_name = f"models--test--model-{i:02d}" - model_dir = temp_cache_dir / cache_name / "snapshots" / "main" - model_dir.mkdir(parents=True, exist_ok=True) - - # Create minimal valid model files - (model_dir / "config.json").write_text(f'{{"model_type": "test", "id": {i}}}') - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - (model_dir / "model.safetensors").write_bytes(b"fake_data" * 50) - - proc = mlx_knife_process(["health"]) - stdout, stderr = proc.communicate(timeout=45) # Allow more time for large cache - - assert proc.returncode is not None, "Health check hung on large cache" - - # Should produce reasonable output - output = stdout + stderr - assert len(output) > 0, "No output for large cache" \ No newline at end of file diff --git a/tests/integration/test_issue_14.py b/tests/integration/test_issue_14.py deleted file mode 100644 index 61666ea..0000000 --- a/tests/integration/test_issue_14.py +++ /dev/null @@ -1,433 +0,0 @@ -""" -Test for Issue #14: Interactive Chat Self-Conversation Bug - -This test ensures that models don't continue conversations autonomously -by generating "You:", "Human:", "Assistant:" markers after their response. - -This test is self-contained and manages its own MLX Knife server instance. -""" - -import logging -import re -import signal -import subprocess -import time -from typing import List, Tuple - -import psutil -import pytest -import requests - -logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') -logger = logging.getLogger(__name__) - -# Realistic RAM requirements for 4-bit quantized models (in GB) -# Based on actual testing on Apple Silicon Macs -MODEL_RAM_REQUIREMENTS = { - "0.5B": 1, "1B": 2, "3B": 4, "4B": 5, - "7B": 8, "8x7B": 16, "24B": 20, "30B": 24, - "70B": 40, "480B": 180 # MoE with overhead, needs 96GB+ -} - -# Self-conversation patterns to detect Issue #14 -SELF_CONVERSATION_PATTERNS = [ - r'\nYou:', - r'\nHuman:', - r'\nAssistant:', - r'\nUser:', - r'\n\nYou:', - r'\n\nHuman:', - r'\n\nAssistant:', - r'\n\nUser:', -] - -SERVER_BASE_URL = "http://localhost:8000" -SERVER_PORT = 8000 - - -def extract_model_size(model_name: str) -> str: - """Extract model size from model name.""" - # Match patterns like "30B", "8x7B", "480B", "0.5B", "3.2B", "Phi-3-mini" etc. - size_patterns = [ - r'(\d+(?:\.\d+)?(?:x\d+)?B)', # 30B, 0.5B, 3.2B, 8x7B, 480B - r'Phi-3-mini', # Special case: Phi-3-mini = ~4B - r'Qwen2\.5-(\d+(?:\.\d+)?)B', # Qwen2.5-0.5B - ] - - for pattern in size_patterns: - match = re.search(pattern, model_name) - if match: - if 'Phi-3-mini' in model_name: - return '4B' # Phi-3-mini is ~4B parameters - elif 'Qwen2.5' in model_name: - return f"{match.group(1)}B" # Extract from Qwen2.5-0.5B - else: - return match.group(1) - - return "unknown" - - -def get_available_models() -> List[str]: - """Get list of available models from MLX Knife server.""" - try: - response = requests.get(f"{SERVER_BASE_URL}/v1/models", timeout=10) - response.raise_for_status() - data = response.json() - return [model["id"] for model in data["data"]] - except Exception as e: - pytest.skip(f"Cannot connect to MLX Knife server: {e}") - - -def get_safe_models_for_system() -> List[Tuple[str, str, int]]: - """Get models that fit safely in available system RAM.""" - total_ram_gb = psutil.virtual_memory().total // (1024**3) - available_ram_gb = psutil.virtual_memory().available // (1024**3) - - # Safety margin: use max 80% of available RAM, keep 4GB free minimum - max_usable_gb = min(available_ram_gb * 0.8, total_ram_gb - 4) - - logger.info(f"System RAM: {total_ram_gb}GB total, {available_ram_gb}GB available") - logger.info(f"Safe limit for model testing: {max_usable_gb:.1f}GB") - - safe_models = [] - all_models = get_available_models() - - for model in all_models: - size_str = extract_model_size(model) - required_ram = MODEL_RAM_REQUIREMENTS.get(size_str, 999) - - if required_ram <= max_usable_gb: - safe_models.append((model, size_str, required_ram)) - logger.info(f"✅ {model} ({size_str}) - fits in {required_ram}GB") - else: - logger.warning(f"⏭️ Skipping {model} ({size_str}) - needs {required_ram}GB, have {max_usable_gb:.1f}GB") - - if not safe_models: - pytest.skip("No models fit in available system RAM") - - return safe_models - - -def has_self_conversation_markers(text: str) -> bool: - """Check if text contains self-conversation markers indicating Issue #14.""" - for pattern in SELF_CONVERSATION_PATTERNS: - if re.search(pattern, text): - return True - return False - - -def chat_completion_request(model_name: str, prompt: str, max_tokens: int = 150) -> str: - """Send chat completion request to MLX Knife server.""" - payload = { - "model": model_name, - "messages": [{"role": "user", "content": prompt}], - "max_tokens": max_tokens, - "stream": False - } - - try: - response = requests.post( - f"{SERVER_BASE_URL}/v1/chat/completions", - json=payload, - timeout=60 - ) - response.raise_for_status() - data = response.json() - return data["choices"][0]["message"]["content"] - except Exception as e: - pytest.fail(f"Chat completion failed for {model_name}: {e}") - - -@pytest.mark.server -def test_issue_14_self_conversation_regression_original(mlx_server, model_name: str, size_str: str, ram_needed: int): - """ - Test Issue #14: Ensure models don't continue conversations autonomously. - - This test verifies that models stop cleanly after their response without - generating additional conversation turns like "You:", "Human:", etc. - """ - logger.info(f"🦫 Testing Issue #14 with {model_name} ({size_str}, {ram_needed}GB)") - - # Use constrained prompt to encourage natural stopping - test_prompt = "Write a short story about a friendly dragon in exactly 50 words." - - start_time = time.time() - response = chat_completion_request(model_name, test_prompt, max_tokens=100) - duration = time.time() - start_time - - logger.info(f"⏱️ Response time: {duration:.2f}s") - logger.info(f"📝 Response preview: {response[:100]}...") - - # Check for Issue #14: self-conversation markers - if has_self_conversation_markers(response): - # Log the problematic response for debugging - logger.error(f"❌ Self-conversation detected in {model_name}:") - logger.error(f"Full response: {repr(response)}") - pytest.fail(f"Issue #14 regression: {model_name} shows self-conversation markers") - - logger.info(f"✅ {model_name}: No self-conversation detected - Issue #14 fix working!") - - -def find_existing_mlxk_servers() -> List[psutil.Process]: - """Find any existing MLX Knife server processes.""" - servers = [] - for proc in psutil.process_iter(['pid', 'name', 'cmdline']): - try: - if proc.info['cmdline'] and any('mlxk' in arg and 'server' in arg for arg in proc.info['cmdline']): - servers.append(proc) - except (psutil.NoSuchProcess, psutil.AccessDenied): - continue - return servers - - -def cleanup_zombie_servers(port: int): - """Clean up any zombie MLX Knife servers on the specified port.""" - logger.info(f"🧹 Checking for existing servers on port {port}") - - # Check for processes using the port - handle macOS permission issues - try: - connections = psutil.net_connections(kind='inet') - except (psutil.AccessDenied, PermissionError) as e: - logger.warning(f"⚠️ Cannot scan network connections (permission denied): {e}") - logger.info("🔧 Falling back to process-based cleanup only") - connections = [] - - for conn in connections: - if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN: - try: - proc = psutil.Process(conn.pid) - logger.warning(f"⚠️ Found process {proc.pid} listening on port {port}: {proc.cmdline()}") - - if 'mlxk' in ' '.join(proc.cmdline()) and 'server' in ' '.join(proc.cmdline()): - logger.info(f"🛑 Terminating existing MLX Knife server {proc.pid}") - proc.terminate() - try: - proc.wait(timeout=5) - logger.info(f"✅ Server {proc.pid} terminated gracefully") - except psutil.TimeoutExpired: - logger.warning(f"⚡ Force killing server {proc.pid}") - proc.kill() - proc.wait() - else: - logger.error(f"❌ Port {port} is occupied by non-MLX process {proc.pid}") - raise RuntimeError(f"Port {port} is busy with: {proc.cmdline()}") - - except (psutil.NoSuchProcess, psutil.AccessDenied): - continue - - # Also check for any MLX Knife server processes (even if not on our port) - existing_servers = find_existing_mlxk_servers() - for server in existing_servers: - logger.warning(f"⚠️ Found zombie MLX Knife server: {server.pid}") - try: - server.terminate() - server.wait(timeout=5) - logger.info(f"✅ Cleaned up zombie server {server.pid}") - except (psutil.TimeoutExpired, psutil.NoSuchProcess): - try: - server.kill() - logger.info(f"⚡ Force killed zombie server {server.pid}") - except psutil.NoSuchProcess: - pass - - -class MLXKnifeServerManager: - """Context manager for MLX Knife server lifecycle with zombie cleanup.""" - - def __init__(self, port: int = 8000): - self.port = port - self.process = None - self.base_url = f"http://localhost:{port}" - - def start_server(self) -> bool: - """Start MLX Knife server and wait for it to be ready.""" - try: - # First, clean up any zombies or port conflicts - cleanup_zombie_servers(self.port) - - # Check if server is already running (after cleanup) - if self.is_server_running(): - logger.info("🟢 MLX Knife server already running") - return True - - logger.info(f"🚀 Starting MLX Knife server on port {self.port}") - - # Start server process - use sys.executable to ensure same Python env - import sys - self.process = subprocess.Popen( - [sys.executable, "-m", "mlx_knife.cli", "server", "--port", str(self.port)], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True - ) - - logger.info(f"📋 Started process PID: {self.process.pid}") - - # Give it a moment to fail fast if there's an immediate error - time.sleep(1) - if self.process.poll() is not None: - stdout, stderr = self.process.communicate() - logger.error(f"❌ Server failed immediately:") - logger.error(f"stdout: {stdout}") - logger.error(f"stderr: {stderr}") - return False - - # Wait for server to be ready (max 30 seconds) - for _ in range(60): # 30 seconds, 0.5s intervals - if self.is_server_running(): - logger.info("✅ MLX Knife server is ready") - return True - time.sleep(0.5) - - # Timeout - get final output - stdout, stderr = "", "" - if self.process: - try: - if self.process.poll() is None: - stdout, stderr = self.process.communicate(timeout=2) - else: - stdout, stderr = self.process.communicate() - except subprocess.TimeoutExpired: - stdout, stderr = "timeout", "timeout" - - logger.error("❌ Server failed to start within timeout") - logger.error(f"Final stdout: {stdout}") - logger.error(f"Final stderr: {stderr}") - self.stop_server() - return False - - except Exception as e: - import traceback - logger.error(f"❌ Failed to start server: {e}") - logger.error(f"Full traceback: {traceback.format_exc()}") - self.stop_server() - return False - - def stop_server(self): - """Stop MLX Knife server if running.""" - if self.process: - logger.info("🛑 Stopping MLX Knife server") - self.process.terminate() - try: - self.process.wait(timeout=10) - except subprocess.TimeoutExpired: - logger.warning("⚠️ Server didn't stop gracefully, killing...") - self.process.kill() - self.process.wait() - self.process = None - - def is_server_running(self) -> bool: - """Check if server is running and healthy.""" - try: - response = requests.get(f"{self.base_url}/health", timeout=2) - return response.status_code == 200 - except: - return False - - def __enter__(self): - if not self.start_server(): - pytest.skip("Failed to start MLX Knife server") - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.stop_server() - - -@pytest.fixture(scope="module") -def mlx_server(): - """Pytest fixture to manage MLX Knife server for all tests in module.""" - with MLXKnifeServerManager(SERVER_PORT) as server: - yield server - - -@pytest.mark.server -def test_server_health(mlx_server): - """Verify MLX Knife server is running and healthy.""" - assert mlx_server.is_server_running(), "MLX Knife server is not healthy" - logger.info("🟢 MLX Knife server is healthy") - - -@pytest.mark.server -def test_issue_14_self_conversation_regression(mlx_server, model_name: str, size_str: str, ram_needed: int): - """ - Test Issue #14: Ensure models don't continue conversations autonomously. - - This test verifies that models stop cleanly after their response without - generating additional conversation turns like "You:", "Human:", etc. - """ - logger.info(f"🦫 Testing Issue #14 with {model_name} ({size_str}, {ram_needed}GB)") - - # Use constrained prompt to encourage natural stopping - test_prompt = "Write a short story about a friendly dragon in exactly 50 words." - - start_time = time.time() - response = chat_completion_request(model_name, test_prompt, max_tokens=100) - duration = time.time() - start_time - - logger.info(f"⏱️ Response time: {duration:.2f}s") - logger.info(f"📝 Response preview: {response[:100]}...") - - # Check for Issue #14: self-conversation markers - if has_self_conversation_markers(response): - # Log the problematic response for debugging - logger.error(f"❌ Self-conversation detected in {model_name}:") - logger.error(f"Full response: {repr(response)}") - pytest.fail(f"Issue #14 regression: {model_name} shows self-conversation markers") - - logger.info(f"✅ {model_name}: No self-conversation detected - Issue #14 fix working!") - - -def get_safe_models_lazy(): - """Lazy evaluation for parametrize to avoid import-time server calls.""" - try: - return get_safe_models_for_system() - except: - # Fallback for when server isn't running yet - return [("test-model", "1B", 1)] - - -# Dynamic test generation at runtime instead of import time -def pytest_generate_tests(metafunc): - """Dynamic test parametrization to avoid import-time server calls.""" - if "model_name" in metafunc.fixturenames: - # Only get models when actually running tests, not during import - try: - with MLXKnifeServerManager() as server: - models = get_safe_models_for_system() - metafunc.parametrize("model_name,size_str,ram_needed", models) - except Exception as e: - pytest.skip(f"Cannot set up server for testing: {e}") - - -if __name__ == "__main__": - # Quick smoke test - start server first - print("🦫 MLX Knife Issue #14 Test - Smoke Test") - print("=" * 50) - - # Test server start directly without context manager - manager = MLXKnifeServerManager() - success = manager.start_server() - - print(f"🏁 Server start result: {success}") - - if success: - try: - models = get_safe_models_for_system() - print(f"\n📊 Safe models for this system: {len(models)}") - - total_ram = psutil.virtual_memory().total // (1024**3) - available_ram = psutil.virtual_memory().available // (1024**3) - print(f"💾 System RAM: {total_ram}GB total, {available_ram}GB available") - print() - - for model, size, ram in models: - print(f" 🎯 {model}") - print(f" └─ Size: {size}, RAM needed: {ram}GB") - - print(f"\n🚀 Ready to run: pytest tests/integration/test_issue_14.py -v") - - finally: - manager.stop_server() - - else: - print("💡 Check the logs above for server start failure details") \ No newline at end of file diff --git a/tests/integration/test_issue_15_16.py b/tests/integration/test_issue_15_16.py deleted file mode 100644 index e73c2a0..0000000 --- a/tests/integration/test_issue_15_16.py +++ /dev/null @@ -1,404 +0,0 @@ -""" -Test for Issues #15 & #16: Dynamic Model-Aware Token Limits - -Issue #15: Token-Limit vs Stop-Token Race Condition -- Models cut off by artificial token limits before natural stopping -- Solution: Context-aware token policies based on model capabilities - -Issue #16: Interactive vs Server Token Limit Policies -- Interactive mode should allow unlimited tokens for natural completion -- Server mode needs DoS protection with reasonable limits -- Solution: Different token policies per usage context - -This test is self-contained and manages its own MLX Knife server instance. -""" - -import json -import logging -import re -import signal -import subprocess -import tempfile -import time -from pathlib import Path -from typing import Dict, List, Tuple - -import psutil -import pytest -import requests - -logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') -logger = logging.getLogger(__name__) - -# Realistic RAM requirements for 4-bit quantized models (in GB) -MODEL_RAM_REQUIREMENTS = { - "0.5B": 1, "1B": 2, "3B": 4, "4B": 5, - "7B": 8, "8x7B": 16, "24B": 20, "30B": 24, - "70B": 40, "480B": 180 -} - -SERVER_BASE_URL = "http://localhost:8001" # Different port to avoid conflicts -SERVER_PORT = 8001 - - -def extract_model_size(model_name: str) -> str: - """Extract model size from model name.""" - # Match patterns like "30B", "8x7B", "480B", "0.5B", "3.2B", "Phi-3-mini" etc. - size_patterns = [ - r'(\d+x\d+B)', # MoE models like "8x7B" - r'(\d+\.?\d*B)', # Standard like "30B", "0.5B", "3.2B" - r'(mini|small|medium|large)', # Qualitative sizes - ] - - for pattern in size_patterns: - match = re.search(pattern, model_name, re.IGNORECASE) - if match: - size = match.group(1).lower() - # Map qualitative sizes to quantitative - if size == 'mini': - return '3B' # Phi-3-mini is ~4B params - elif size == 'small': - return '1B' - elif size == 'medium': - return '7B' - elif size == 'large': - return '30B' - return size.upper() - - return "3B" # Default fallback - - -def get_available_ram_gb() -> int: - """Get available system RAM in GB.""" - try: - return int(psutil.virtual_memory().available / (1024**3)) - except Exception: - return 8 # Conservative fallback - - -def get_suitable_models(available_models: List[str]) -> List[str]: - """Filter models based on available RAM.""" - available_ram = get_available_ram_gb() - logger.info(f"Available RAM: {available_ram}GB") - - suitable = [] - for model in available_models: - size = extract_model_size(model) - required_ram = MODEL_RAM_REQUIREMENTS.get(size, 8) - - if required_ram <= available_ram: - suitable.append(model) - logger.info(f"✓ {model} ({size}, {required_ram}GB) - Suitable") - else: - logger.info(f"✗ {model} ({size}, {required_ram}GB) - Too large") - - return suitable - - -def get_cached_models() -> List[str]: - """Get list of cached MLX models.""" - try: - result = subprocess.run( - ["mlxk", "list", "--framework", "mlx"], - capture_output=True, text=True, timeout=10 - ) - if result.returncode != 0: - return [] - - models = [] - for line in result.stdout.split('\n'): - line = line.strip() - if line and not line.startswith('MODEL') and not line.startswith('NAME'): - # Extract model name from table format - parts = line.split() - if len(parts) >= 1 and not parts[0] in ['MODEL', 'NAME']: - models.append(parts[0]) - - return models - except Exception as e: - logger.warning(f"Failed to get cached models: {e}") - return [] - - -def extract_context_length_from_model(model_name: str) -> int: - """Extract context length from a real model's config.""" - try: - result = subprocess.run( - ["mlxk", "show", model_name, "--config"], - capture_output=True, text=True, timeout=10 - ) - if result.returncode != 0: - return 4096 - - # Extract JSON from the output (it comes after "Config:") - config_text = result.stdout - - # Find the JSON part after "Config:" - config_start = config_text.find("Config:") - if config_start == -1: - return 4096 - - json_text = config_text[config_start + 7:].strip() # Skip "Config:" - - try: - config = json.loads(json_text) - context_keys = [ - "max_position_embeddings", - "n_positions", - "context_length", - "max_sequence_length", - "seq_len" - ] - - for key in context_keys: - if key in config: - return config[key] - - return 4096 - except json.JSONDecodeError: - return 4096 - - except Exception: - return 4096 - - -class MLXKnifeServer: - """Manages MLX Knife server lifecycle for testing.""" - - def __init__(self, port: int = SERVER_PORT): - self.port = port - self.process = None - self.base_url = f"http://localhost:{port}" - - def start(self) -> bool: - """Start the MLX Knife server.""" - try: - cmd = [ - "mlxk", "server", - "--host", "127.0.0.1", - "--port", str(self.port), - "--max-tokens", "1000", # Conservative default for testing - "--log-level", "warning" - ] - - self.process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True - ) - - # Wait for server to start - for attempt in range(30): - try: - response = requests.get(f"{self.base_url}/v1/models", timeout=2) - if response.status_code == 200: - logger.info(f"MLX Knife server started on port {self.port}") - return True - except requests.RequestException: - pass - - if self.process.poll() is not None: - logger.error("Server process died during startup") - return False - - time.sleep(1) - - logger.error("Server failed to start within timeout") - return False - - except Exception as e: - logger.error(f"Failed to start server: {e}") - return False - - def stop(self): - """Stop the MLX Knife server.""" - if self.process: - try: - # Try graceful shutdown first - self.process.terminate() - try: - self.process.wait(timeout=10) - except subprocess.TimeoutExpired: - # Force kill if not responding - self.process.kill() - self.process.wait(timeout=5) - except Exception as e: - logger.warning(f"Error stopping server: {e}") - finally: - self.process = None - - def chat_completion(self, model: str, messages: List[Dict], max_tokens: int = None) -> Dict: - """Send chat completion request.""" - payload = { - "model": model, - "messages": messages, - "temperature": 0.3, - "stream": False - } - if max_tokens: - payload["max_tokens"] = max_tokens - - response = requests.post( - f"{self.base_url}/v1/chat/completions", - json=payload, - timeout=60 - ) - response.raise_for_status() - return response.json() - - -@pytest.fixture(scope="module") -def mlx_server(): - """Provide MLX Knife server for the test session.""" - server = MLXKnifeServer() - - if not server.start(): - pytest.skip("Failed to start MLX Knife server") - - try: - yield server - finally: - server.stop() - - -@pytest.fixture(scope="module") -def available_models(): - """Get available models suitable for current system.""" - all_models = get_cached_models() - if not all_models: - pytest.skip("No MLX models found in cache") - - suitable = get_suitable_models(all_models) - if not suitable: - pytest.skip("No suitable models found for current RAM") - - return suitable - - -@pytest.mark.server -class TestIssue15TokenLimitVsStopTokenRace: - """Test Issue #15: Token-Limit vs Stop-Token Race Condition Resolution.""" - - def test_model_context_length_extraction(self, available_models): - """Test that we can extract context length from real models.""" - model = available_models[0] - context_length = extract_context_length_from_model(model) - - assert context_length >= 512, f"Context length too small for {model}: {context_length}" - assert context_length <= 1048576, f"Context length unrealistic for {model}: {context_length}" # 1M tokens max - - logger.info(f"Model {model} has context length: {context_length}") - - def test_realistic_token_limits_prevent_race_condition(self, mlx_server, available_models): - """Test that realistic token limits prevent race conditions.""" - model = available_models[0] - context_length = extract_context_length_from_model(model) - - # Request tokens close to but under the expected server limit (context/2) - server_limit = context_length // 2 - test_tokens = min(server_limit - 100, 500) # Conservative test - - messages = [{"role": "user", "content": "Write a short story about a robot."}] - - response = mlx_server.chat_completion(model, messages, max_tokens=test_tokens) - - assert "choices" in response - assert len(response["choices"]) > 0 - choice = response["choices"][0] - assert "message" in choice - assert "content" in choice["message"] - - content = choice["message"]["content"] - assert len(content) > 0, "No content generated" - - # The key test: model should generate reasonable content within limits - # without being cut off mid-sentence due to race conditions - logger.info(f"Generated {len(content)} characters with {test_tokens} token limit") - - -@pytest.mark.server -class TestIssue16InteractiveVsServerTokenPolicies: - """Test Issue #16: Interactive vs Server Token Limit Policies Resolution.""" - - def test_server_mode_uses_dos_protection_limits(self, mlx_server, available_models): - """Test that server mode uses DoS protection (context/2).""" - model = available_models[0] - context_length = extract_context_length_from_model(model) - server_limit = context_length // 2 - - # Request more tokens than server limit should allow, but not too excessive for testing - excessive_tokens = min(server_limit + 200, 800) # Keep reasonable for testing - - messages = [{"role": "user", "content": "Write a brief summary of machine learning."}] - - # This should work without errors - the server should internally - # limit tokens to the DoS protection limit - response = mlx_server.chat_completion(model, messages, max_tokens=excessive_tokens) - - assert "choices" in response - assert len(response["choices"]) > 0 - choice = response["choices"][0] - assert "message" in choice - assert "content" in choice["message"] - - content = choice["message"]["content"] - assert len(content) > 0 - - # The response should be successful, proving the server handles - # excessive token requests gracefully - logger.info(f"Server handled excessive token request ({excessive_tokens}) gracefully") - logger.info(f"Model context: {context_length}, Server limit: {server_limit}, Generated content length: {len(content)}") - - def test_server_honors_reasonable_token_requests(self, mlx_server, available_models): - """Test that server honors reasonable token requests.""" - model = available_models[0] - context_length = extract_context_length_from_model(model) - server_limit = context_length // 2 - - # Request reasonable number of tokens (well under limit) - reasonable_tokens = min(server_limit // 4, 200) - - messages = [{"role": "user", "content": "Say hello."}] - - response = mlx_server.chat_completion(model, messages, max_tokens=reasonable_tokens) - - assert "choices" in response - assert len(response["choices"]) > 0 - choice = response["choices"][0] - assert "message" in choice - assert "content" in choice["message"] - - content = choice["message"]["content"] - assert len(content) > 0 - assert "hello" in content.lower() or "hi" in content.lower() - - logger.info(f"Server honored reasonable token request ({reasonable_tokens})") - - def test_model_capabilities_vs_hardcoded_limits(self, available_models): - """Test that models with different context lengths get appropriate limits.""" - if len(available_models) < 2: - pytest.skip("Need multiple models to compare context lengths") - - model_contexts = [] - for model in available_models[:3]: # Test up to 3 models - context_length = extract_context_length_from_model(model) - model_contexts.append((model, context_length)) - - # Verify that different models have different context lengths - # (or at least our system recognizes their individual capabilities) - contexts = [ctx for _, ctx in model_contexts] - - # At minimum, verify context extraction worked - for model, context in model_contexts: - assert context >= 1024, f"Model {model} context too small: {context}" - logger.info(f"Model {model}: {context} tokens context") - - # The key insight: No hardcoded 500/2000 token limits! - # Each model gets limits based on its actual capabilities - for model, context in model_contexts: - server_limit = context // 2 - # Server limits should be much higher than old hardcoded limits - # for models with large context windows - if context >= 4096: - assert server_limit >= 2048, f"Model {model} should have server limit >= 2048, got {server_limit}" \ No newline at end of file diff --git a/tests/integration/test_lock_cleanup_bug.py b/tests/integration/test_lock_cleanup_bug.py deleted file mode 100644 index 73be0a2..0000000 --- a/tests/integration/test_lock_cleanup_bug.py +++ /dev/null @@ -1,99 +0,0 @@ -#!/usr/bin/env python3 -""" -Integration test for lock cleanup bug. -This test reproduces the real bug found in Issue #24. -""" - -from pathlib import Path -import pytest - -from mlx_knife.cache_utils import _cleanup_model_locks - - -@pytest.mark.usefixtures("temp_cache_dir") -class TestLockCleanupBug: - """Integration tests for lock cleanup functionality.""" - - def test_lock_cleanup_path_bug(self, temp_cache_dir, patch_model_cache): - """Test that reproduces the lock cleanup path bug. - - The bug: _cleanup_model_locks uses MODEL_CACHE.parent instead of MODEL_CACHE, - causing it to look for locks in the wrong directory. - - HF Cache structure: - cache_root/ - └── hub/ ← MODEL_CACHE - ├── .locks/ ← Correct location - └── models--name/ - - Bug: looks in cache_root/.locks/ instead of cache_root/hub/.locks/ - """ - hub_cache = temp_cache_dir / "hub" - - with patch_model_cache(hub_cache): - # Create test model structure - model_name = "test-org/broken-model" - cache_dir_name = "models--test-org--broken-model" - - # Create model directory (not needed for lock cleanup, but realistic) - model_dir = hub_cache / cache_dir_name - model_dir.mkdir() - - # Create lock files in CORRECT location: hub/.locks/ - locks_dir = hub_cache / ".locks" / cache_dir_name - locks_dir.mkdir(parents=True) - (locks_dir / "download.lock").touch() - (locks_dir / "process.lock").touch() - (locks_dir / "huggingface.lock").write_text("PID:12345") - (locks_dir / "another.lock").touch() - - # Verify setup - assert locks_dir.exists(), "Lock directory should exist" - lock_files = list(locks_dir.iterdir()) - assert len(lock_files) == 4, f"Should have 4 lock files, got {len(lock_files)}" - - # This should clean up the locks, but currently fails due to path bug - _cleanup_model_locks(model_name, force=True) - - # BUG: Lock directory still exists because function looks in wrong path - # This assertion will FAIL until the bug is fixed - assert not locks_dir.exists(), ( - f"❌ BUG REPRODUCED: Lock directory still exists at {locks_dir}. " - f"The _cleanup_model_locks function is looking in the wrong path." - ) - - def test_lock_cleanup_empty_directory(self, temp_cache_dir, patch_model_cache): - """Test that _cleanup_model_locks handles empty lock directories gracefully.""" - hub_cache = temp_cache_dir / "hub" - - with patch_model_cache(hub_cache): - model_name = "test-org/empty-locks" - cache_dir_name = "models--test-org--empty-locks" - - # Create empty lock directory - locks_dir = hub_cache / ".locks" / cache_dir_name - locks_dir.mkdir(parents=True) - - assert locks_dir.exists() - assert len(list(locks_dir.iterdir())) == 0 - - # Should handle empty directory gracefully (no-op) - _cleanup_model_locks(model_name, force=True) - - # Empty directory should still exist (function returns early) - # This will also fail due to path bug, but for different reason - - def test_lock_cleanup_nonexistent_locks(self, temp_cache_dir, patch_model_cache): - """Test that _cleanup_model_locks handles missing lock directories gracefully.""" - hub_cache = temp_cache_dir / "hub" - - with patch_model_cache(hub_cache): - model_name = "test-org/no-locks" - - # Don't create any lock directory - - # Should handle gracefully (no-op) - _cleanup_model_locks(model_name, force=True) - - # This should pass (no error thrown) - assert True, "Function should handle missing lock directories gracefully" \ No newline at end of file diff --git a/tests/integration/test_process_lifecycle.py b/tests/integration/test_process_lifecycle.py deleted file mode 100644 index 790cb55..0000000 --- a/tests/integration/test_process_lifecycle.py +++ /dev/null @@ -1,270 +0,0 @@ -""" -High Priority Tests: Process Lifecycle Management - -Tests ensure clean process handling and resource management: -- No zombie processes after normal exit or interruption -- Proper signal handling (SIGTERM, SIGKILL, SIGINT) -- Resource management (file handles, sockets, memory) -- Clean streaming interruption -""" -import pytest -import subprocess -import signal -import time -import psutil -import os -from pathlib import Path - - -@pytest.mark.timeout(30) -class TestProcessLifecycle: - """Test process lifecycle management and cleanup.""" - - def test_no_zombie_processes_normal_exit(self, mlx_knife_process, process_monitor): - """Ensure normal exit leaves no background processes.""" - # Start a simple command that should exit cleanly - proc = mlx_knife_process(["list"]) - main_pid = proc.pid - - # Track child processes before termination - children_before = process_monitor["get_process_tree"](main_pid) - - # Wait for normal completion - return_code = proc.wait(timeout=10) - - # Verify main process exited normally - assert return_code == 0 - - # Verify no child processes remain - assert process_monitor["wait_for_cleanup"](main_pid, timeout=5) - - # Double-check: no processes should be running - for child in children_before: - assert not child.is_running(), f"Zombie process detected: PID {child.pid}" - - def test_no_zombie_processes_sigint(self, mlx_knife_process, process_monitor, temp_cache_dir): - """Ensure SIGINT (Ctrl+C) kills all child processes.""" - # Create a mock model for a longer-running command - mock_model_cache = self._create_simple_mock_model(temp_cache_dir) - - # Start a command that would run longer (health check) - proc = mlx_knife_process(["health"]) - main_pid = proc.pid - - # Give it a moment to start and potentially spawn children - time.sleep(0.5) - - # Track child processes - children_before = process_monitor["get_process_tree"](main_pid) - - # Send SIGINT (Ctrl+C equivalent) - proc.send_signal(signal.SIGINT) - - # Wait for termination - try: - return_code = proc.wait(timeout=10) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Process did not respond to SIGINT within timeout") - - # Verify process was interrupted - assert return_code != 0 # Should not exit normally - - # Verify all child processes are cleaned up - assert process_monitor["wait_for_cleanup"](main_pid, timeout=5) - - for child in children_before: - assert not child.is_running(), f"Child process survived SIGINT: PID {child.pid}" - - def test_no_zombie_processes_sigterm(self, mlx_knife_process, process_monitor, temp_cache_dir): - """Ensure SIGTERM leads to graceful shutdown.""" - # Create a mock model - mock_model_cache = self._create_simple_mock_model(temp_cache_dir) - - # Start health check command - proc = mlx_knife_process(["health"]) - main_pid = proc.pid - - time.sleep(0.5) - children_before = process_monitor["get_process_tree"](main_pid) - - # Send SIGTERM - proc.send_signal(signal.SIGTERM) - - try: - return_code = proc.wait(timeout=10) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Process did not respond to SIGTERM within timeout") - - # Verify graceful shutdown - assert return_code != 0 # Interrupted - - # Verify cleanup - assert process_monitor["wait_for_cleanup"](main_pid, timeout=5) - - for child in children_before: - assert not child.is_running(), f"Child process survived SIGTERM: PID {child.pid}" - - def test_process_cleanup_after_sigkill(self, mlx_knife_process, process_monitor, temp_cache_dir): - """Test cleanup after SIGKILL (should kill immediately).""" - mock_model_cache = self._create_simple_mock_model(temp_cache_dir) - - proc = mlx_knife_process(["health"]) - main_pid = proc.pid - - time.sleep(0.5) - children_before = process_monitor["get_process_tree"](main_pid) - - # SIGKILL should kill immediately - proc.send_signal(signal.SIGKILL) - - try: - return_code = proc.wait(timeout=5) - except subprocess.TimeoutExpired: - pytest.fail("Process did not die from SIGKILL") - - # SIGKILL has specific return code - assert return_code == -signal.SIGKILL - - # Child processes should be cleaned up by OS - assert process_monitor["wait_for_cleanup"](main_pid, timeout=5) - - def test_download_worker_cleanup(self, mlx_knife_process, process_monitor, temp_cache_dir, patch_model_cache): - """Ensure download workers don't become zombies - uses isolated cache.""" - # This test simulates download interruption with Phi-3-mini in ISOLATED cache - # Any broken download will be auto-cleaned, user cache stays pristine - - with patch_model_cache(temp_cache_dir / "hub"): - proc = mlx_knife_process(["pull", "mlx-community/Phi-3-mini-4k-instruct-4bit", "--no-progress"]) - main_pid = proc.pid - - # Let download start - time.sleep(2.0) - - children_before = process_monitor["get_process_tree"](main_pid) - - # Interrupt the download - proc.send_signal(signal.SIGINT) - - try: - return_code = proc.wait(timeout=15) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Download process did not respond to interruption") - - # Verify cleanup - this is critical for download workers - assert process_monitor["wait_for_cleanup"](main_pid, timeout=10) - - for child in children_before: - if child.is_running(): - # Give more details about surviving process - try: - cmd = " ".join(child.cmdline()) - pytest.fail(f"Download worker survived: PID {child.pid}, CMD: {cmd}") - except (psutil.NoSuchProcess, psutil.AccessDenied): - pass # Process died while we were checking - - print("✓ Download interrupt test completed - any broken Phi-3-mini in isolated cache will be auto-cleaned") - - def test_streaming_interruption_cleanup(self, mlx_knife_process, process_monitor, temp_cache_dir, patch_model_cache): - """Test clean cancellation of token generation streaming - uses tiny test model for isolation.""" - # Use tiny-random-gpt2 for streaming tests to avoid dependencies on user cache - test_model = "hf-internal-testing/tiny-random-gpt2" - test_prompt = "Write a long story about a cat and a dog." - - with patch_model_cache(temp_cache_dir / "hub"): - # First download the model for this isolated test - from mlx_knife.hf_download import pull_model - from unittest.mock import patch - - with patch('builtins.input', return_value='y'): - pull_model(test_model) - - proc = mlx_knife_process(["run", test_model, test_prompt]) - - # Let it start generating, then interrupt - time.sleep(2) # Give it time to start - - # Send SIGINT (Ctrl+C) to interrupt gracefully - proc.send_signal(signal.SIGINT) - - try: - stdout, stderr = proc.communicate(timeout=10) - # Should terminate gracefully - assert proc.returncode is not None, "Process didn't terminate after SIGINT" - except subprocess.TimeoutExpired: - # If it doesn't respond to SIGINT, force kill - proc.kill() - stdout, stderr = proc.communicate() - pytest.fail("Process didn't respond to SIGINT - cleanup may have failed") - - # Check that we got some output before interruption - assert len(stdout) >= 0, "Process should handle interruption gracefully" - - print("✓ Streaming interrupt test completed - test model in isolated cache will be auto-cleaned") - - def test_file_handle_management(self, mlx_knife_process, temp_cache_dir): - """Verify no file handle leaks after process termination.""" - # Get initial file descriptor count - initial_fds = len(os.listdir("/proc/self/fd")) if os.path.exists("/proc/self/fd") else 0 - - mock_model_cache = self._create_simple_mock_model(temp_cache_dir) - - # Run several operations - for _ in range(3): - proc = mlx_knife_process(["list"]) - proc.wait(timeout=10) - - # Check file descriptors haven't grown significantly - if os.path.exists("/proc/self/fd"): - final_fds = len(os.listdir("/proc/self/fd")) - # Allow some tolerance for test framework overhead - assert final_fds <= initial_fds + 5, f"Potential file handle leak: {initial_fds} -> {final_fds}" - - def _create_simple_mock_model(self, temp_cache_dir: Path) -> Path: - """Helper to create a simple mock model for testing.""" - cache_name = "models--test--model" - model_dir = temp_cache_dir / cache_name / "snapshots" / "main" - model_dir.mkdir(parents=True, exist_ok=True) - - (model_dir / "config.json").write_text('{"model_type": "test"}') - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - (model_dir / "model.safetensors").write_bytes(b"fake_model_data" * 100) - - return model_dir - - -@pytest.mark.timeout(60) -class TestResourceManagement: - """Test resource management and memory cleanup.""" - - def test_memory_cleanup_after_operations(self, mlx_knife_process, temp_cache_dir): - """Verify memory is properly released after operations.""" - # This is a basic test - real memory testing would require more sophisticated tools - mock_model_cache = self._create_simple_mock_model(temp_cache_dir) - - # Run operations and ensure they complete without hanging - operations = [ - ["list"], - ["health"], - ["show", "test/model"] # This should gracefully handle non-existent model - ] - - for op in operations: - proc = mlx_knife_process(op) - return_code = proc.wait(timeout=15) - # Operations should complete (may fail, but should not hang) - assert return_code is not None, f"Operation {op} hung" - - def _create_simple_mock_model(self, temp_cache_dir: Path) -> Path: - """Helper to create a simple mock model for testing.""" - cache_name = "models--test--model" - model_dir = temp_cache_dir / cache_name / "snapshots" / "main" - model_dir.mkdir(parents=True, exist_ok=True) - - (model_dir / "config.json").write_text('{"model_type": "test"}') - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - (model_dir / "model.safetensors").write_bytes(b"fake_model_data" * 100) - - return model_dir \ No newline at end of file diff --git a/tests/integration/test_real_model_lifecycle.py b/tests/integration/test_real_model_lifecycle.py deleted file mode 100644 index 1f88bd1..0000000 --- a/tests/integration/test_real_model_lifecycle.py +++ /dev/null @@ -1,349 +0,0 @@ -""" -Integration tests for real model lifecycle using tiny real models. - -This replaces heavily mocked tests with comprehensive integration tests using -hf-internal-testing/tiny-random-gpt2 (112k params, ~500KB) to test: -- Real file system operations -- Real path resolution logic -- Real framework detection -- Real lock cleanup (our main bug from Issue #23) -- End-to-end model lifecycle: pull → list → show → rm - -Strategy: ONE pull for all tests to be efficient, then comprehensive testing -of the full pipeline with real files and directories. -""" -import pytest -import os -import shutil -from pathlib import Path -from unittest.mock import patch -from mlx_knife.hf_download import pull_model -from mlx_knife.cache_utils import ( - list_models, show_model, rm_model, find_matching_models, - resolve_single_model, is_model_healthy, detect_framework, - hf_to_cache_dir, MODEL_CACHE -) - - -class TestRealModelLifecycle: - """Test complete model lifecycle with real tiny model in isolated cache.""" - - TEST_MODEL = "hf-internal-testing/tiny-random-gpt2" - EXPECTED_SIZE_RANGE = (10_000_000, 15_000_000) # ~12.5MB expected - - @staticmethod - def get_current_model_cache(): - """Get the current model cache path (resolves HF_HOME dynamically).""" - cache_root = Path(os.environ.get("HF_HOME", Path.home() / ".cache/huggingface")) - return cache_root / "hub" - - @pytest.fixture(scope="class", autouse=True) - def setup_isolated_model(self, class_temp_cache_dir): - """Download test model to isolated cache before all tests in this class.""" - print(f"\n=== Downloading {self.TEST_MODEL} to isolated test cache ===") - print(f"Test cache location: {class_temp_cache_dir}") - - # Patch MODEL_CACHE to point to our isolated cache - from mlx_knife import cache_utils - original_model_cache = cache_utils.MODEL_CACHE - cache_utils.MODEL_CACHE = class_temp_cache_dir / "hub" - - try: - # Pull the tiny test model (patch input to auto-confirm) - with patch('builtins.input', return_value='y'): - pull_model(self.TEST_MODEL) - - # Verify model exists in isolated cache - cache_dir_name = hf_to_cache_dir(self.TEST_MODEL) - model_cache_path = cache_utils.MODEL_CACHE / cache_dir_name - - if not model_cache_path.exists(): - print(f"HF_HOME: {os.environ.get('HF_HOME', 'not set')}") - print(f"Expected cache path: {model_cache_path}") - print(f"Cache contents: {list(cache_utils.MODEL_CACHE.iterdir()) if cache_utils.MODEL_CACHE.exists() else 'does not exist'}") - pytest.fail(f"Model download failed - cache directory not found: {model_cache_path}") - - print(f"✅ Successfully downloaded {self.TEST_MODEL}") - print(f"📁 Model cached at: {model_cache_path}") - print(f"🔒 Using isolated test cache (user cache untouched)") - - # Fixture runs for all tests in this class - yield - - finally: - # Restore original MODEL_CACHE - cache_utils.MODEL_CACHE = original_model_cache - print(f"\n=== Test cache cleanup and MODEL_CACHE restored ===") - - def test_01_model_downloaded_successfully(self): - """Test that real model download created proper file structure.""" - from mlx_knife import cache_utils - cache_dir_name = hf_to_cache_dir(self.TEST_MODEL) - model_cache_path = cache_utils.MODEL_CACHE / cache_dir_name - - # Verify top-level structure exists - assert model_cache_path.exists(), f"Model cache directory missing: {model_cache_path}" - assert (model_cache_path / "snapshots").exists(), "Snapshots directory missing" - assert (model_cache_path / "refs").exists(), "Refs directory missing" - - # Verify refs/main exists and points to a hash - refs_main = model_cache_path / "refs" / "main" - assert refs_main.exists(), "refs/main missing" - - commit_hash = refs_main.read_text().strip() - assert len(commit_hash) >= 8, f"Invalid commit hash: {commit_hash}" - - # Verify snapshot directory exists for the hash - snapshot_dir = model_cache_path / "snapshots" / commit_hash - assert snapshot_dir.exists(), f"Snapshot directory missing: {snapshot_dir}" - - # Verify essential model files exist - config_json = snapshot_dir / "config.json" - assert config_json.exists(), "config.json missing" - - # Check file size is reasonable (tiny model should be ~500KB total) - total_size = sum(f.stat().st_size for f in snapshot_dir.rglob("*") if f.is_file()) - assert self.EXPECTED_SIZE_RANGE[0] <= total_size <= self.EXPECTED_SIZE_RANGE[1], \ - f"Model size {total_size} outside expected range {self.EXPECTED_SIZE_RANGE}" - - print(f"✓ Real model downloaded: {total_size:,} bytes in {snapshot_dir}") - - def test_02_list_shows_downloaded_model(self): - """Test that list command shows our real downloaded model.""" - # Use list with health check to verify model is detected and healthy - import io - import contextlib - - stdout_capture = io.StringIO() - with contextlib.redirect_stdout(stdout_capture): - list_models(show_all=True, show_health=True) # Show all models with health status - - output = stdout_capture.getvalue() - - # Verify our test model appears in the output - assert self.TEST_MODEL in output or "tiny-random-gpt2" in output, \ - f"Test model not found in list output: {output}" - - print(f"✓ Model appears in list output with health status") - - def test_03_show_detects_real_framework(self): - """Test that show command detects framework for real model.""" - import io - import contextlib - - stdout_capture = io.StringIO() - with contextlib.redirect_stdout(stdout_capture): - show_model(self.TEST_MODEL) - - output = stdout_capture.getvalue() - - # Verify show command produced output about our model - assert self.TEST_MODEL in output or "tiny-random-gpt2" in output, \ - f"Model not found in show output: {output}" - - # Should have framework detection - assert "Framework:" in output, f"Framework detection missing: {output}" - - # Should have health status - assert "Health:" in output, f"Health status missing: {output}" - - # Should show size information - assert any(keyword in output.lower() for keyword in ["size", "gb", "mb", "kb"]), \ - f"Size information missing: {output}" - - print(f"✓ Show command detected framework and health for real model") - - def test_04_find_matching_works_with_real_model(self): - """Test that fuzzy matching works with real model.""" - # Test exact match - exact_matches = find_matching_models(self.TEST_MODEL) - assert len(exact_matches) >= 1, f"Exact match failed for {self.TEST_MODEL}" - - # Test partial match - partial_matches = find_matching_models("tiny-random") - assert len(partial_matches) >= 1, f"Partial match failed for 'tiny-random'" - - # Verify our model is in the matches - model_names = [match[1] for match in partial_matches] - assert any(self.TEST_MODEL in name for name in model_names), \ - f"Test model not found in partial matches: {model_names}" - - print(f"✓ Fuzzy matching works: {len(partial_matches)} matches for 'tiny-random'") - - def test_05_resolve_real_model_paths(self): - """Test that path resolution works with real model.""" - # Test exact model resolution - model_path, resolved_name, commit_hash = resolve_single_model(self.TEST_MODEL) - - assert model_path is not None, f"Failed to resolve model path for {self.TEST_MODEL}" - assert model_path.exists(), f"Resolved path does not exist: {model_path}" - assert resolved_name == self.TEST_MODEL, f"Name resolution incorrect: {resolved_name}" - assert commit_hash is not None, f"Commit hash not resolved" - assert len(commit_hash) >= 8, f"Invalid commit hash: {commit_hash}" - - # Test fuzzy resolution - fuzzy_path, fuzzy_name, fuzzy_hash = resolve_single_model("tiny-random") - - assert fuzzy_path is not None, f"Fuzzy resolution failed for 'tiny-random'" - assert fuzzy_path.exists(), f"Fuzzy resolved path does not exist: {fuzzy_path}" - - # Both should resolve to same model - assert fuzzy_path == model_path, f"Fuzzy and exact paths differ: {fuzzy_path} vs {model_path}" - - print(f"✓ Path resolution works: {model_path}") - - def test_06_health_check_on_real_model(self): - """Test health checking on real model files.""" - # Resolve model to get path - model_path, _, _ = resolve_single_model(self.TEST_MODEL) - assert model_path is not None, "Model resolution failed" - - # Test health check - is_healthy = is_model_healthy(self.TEST_MODEL) - - # Real downloaded model should be healthy - assert is_healthy, f"Real model reported as unhealthy: {self.TEST_MODEL}" - - # Test framework detection - framework = detect_framework(model_path, self.TEST_MODEL) - assert framework is not None, f"Framework detection failed for real model" - assert isinstance(framework, str), f"Framework should be string: {framework}" - assert len(framework) > 0, f"Empty framework detected: {framework}" - - print(f"✓ Health check passed, framework: {framework}") - - # Also test using show command for health verification - import io - import contextlib - - stdout_capture = io.StringIO() - with contextlib.redirect_stdout(stdout_capture): - show_model(self.TEST_MODEL) - - show_output = stdout_capture.getvalue() - assert "Health:" in show_output, f"Health status missing in show output: {show_output}" - - print(f"✓ Show command also reports health status correctly") - - def test_07_rm_cleans_locks_and_model(self): - """Test that rm command cleans both model AND locks (Issue #23 fix).""" - # Verify model exists before deletion - model_path, _, _ = resolve_single_model(self.TEST_MODEL) - assert model_path is not None, "Model should exist before deletion" - assert model_path.exists(), f"Model path should exist before deletion: {model_path}" - - # Get model cache directory and expected locks directory - from mlx_knife import cache_utils - cache_dir_name = hf_to_cache_dir(self.TEST_MODEL) - model_cache_path = cache_utils.MODEL_CACHE / cache_dir_name - locks_dir = cache_utils.MODEL_CACHE / ".locks" / cache_dir_name - - # Create some test lock files if they don't exist - if not locks_dir.exists(): - locks_dir.mkdir(parents=True) - (locks_dir / "test.lock").touch() - - lock_files_before = list(locks_dir.iterdir()) if locks_dir.exists() else [] - - print(f"Before deletion:") - print(f" Model cache: {model_cache_path.exists()}") - print(f" Locks dir: {locks_dir.exists()}") - print(f" Lock files: {len(lock_files_before)}") - - # Remove model with force=True (no prompts) - rm_model(self.TEST_MODEL, force=True) - - # Verify BOTH model and locks are cleaned up - model_exists_after = model_cache_path.exists() - locks_exist_after = locks_dir.exists() - - print(f"After deletion:") - print(f" Model cache: {model_exists_after}") - print(f" Locks dir: {locks_exist_after}") - - # Issue #23 fix: Both should be deleted - assert not model_exists_after, f"Model cache should be deleted: {model_cache_path}" - assert not locks_exist_after, f"Locks directory should be deleted: {locks_dir}" - - print(f"✓ rm command cleaned both model and locks (Issue #23 fix verified)") - - def test_08_model_completely_removed(self): - """Test end-to-end verification that model is completely gone.""" - # Verify model no longer appears in list - import io - import contextlib - - stdout_capture = io.StringIO() - with contextlib.redirect_stdout(stdout_capture): - list_models(show_all=True) # Show all models, not just MLX ones - - output = stdout_capture.getvalue() - - # Our test model should NOT appear in output anymore - assert self.TEST_MODEL not in output, \ - f"Model still appears in list after deletion: {output}" - assert "tiny-random-gpt2" not in output, \ - f"Model name still appears in list after deletion: {output}" - - # Verify resolution fails - model_path, resolved_name, commit_hash = resolve_single_model(self.TEST_MODEL) - assert model_path is None, f"Model path should be None after deletion: {model_path}" - assert resolved_name is None, f"Resolved name should be None after deletion: {resolved_name}" - - # Verify fuzzy matching also fails - matches = find_matching_models("tiny-random") - model_names = [match[1] for match in matches] if matches else [] - assert not any(self.TEST_MODEL in name for name in model_names), \ - f"Model still found in fuzzy matches: {model_names}" - - print(f"✓ Model completely removed from cache and indexes") - - -class TestIntegrationTestSelfCheck: - """Meta-test: Verify integration tests are working properly.""" - - def test_integration_test_downloads_real_files(self): - """Verify this integration test actually downloaded real files.""" - # This test runs after TestRealModelLifecycle, so model should be cleaned up - # But we can verify the test ran by checking if we have network access - # and that the model we tried to download is a real HF model - - model = TestRealModelLifecycle.TEST_MODEL - assert "/" in model, f"Model name should have org/repo format: {model}" - assert "tiny" in model.lower(), f"Should use tiny model for tests: {model}" - assert "gpt2" in model.lower(), f"Should use GPT2 for compatibility: {model}" - - # Verify size expectations are reasonable for integration tests - min_size, max_size = TestRealModelLifecycle.EXPECTED_SIZE_RANGE - assert min_size < max_size, "Size range should be valid" - assert max_size < 20_000_000, "Test model should be reasonably small for CI efficiency" - - print(f"✓ Integration test configuration validated: {model}") - - def test_integration_vs_unit_test_coverage(self): - """Verify integration tests cover areas missed by unit tests.""" - # This integration test should cover: - # 1. Real file system operations (not mocked) - # 2. Real path resolution logic - # 3. Real framework detection - # 4. Real lock cleanup (Issue #23) - # 5. End-to-end workflows - - # Count methods in TestRealModelLifecycle - test_methods = [method for method in dir(TestRealModelLifecycle) - if method.startswith('test_')] - - # Should have comprehensive lifecycle coverage - assert len(test_methods) >= 7, f"Should have comprehensive test coverage: {len(test_methods)} tests" - - # Should test specific functionality - method_names = ' '.join(test_methods) - assert 'download' in method_names, "Should test downloading" - assert 'list' in method_names, "Should test listing" - assert 'show' in method_names, "Should test showing" - assert 'resolve' in method_names, "Should test resolution" - assert 'health' in method_names, "Should test health checks" - assert 'rm' in method_names or 'remove' in method_names, "Should test removal" - assert 'lock' in method_names, "Should test lock cleanup (Issue #23)" - - print(f"✓ Integration tests provide comprehensive lifecycle coverage: {len(test_methods)} tests") \ No newline at end of file diff --git a/tests/integration/test_run_command_advanced.py b/tests/integration/test_run_command_advanced.py deleted file mode 100644 index b64d0d0..0000000 --- a/tests/integration/test_run_command_advanced.py +++ /dev/null @@ -1,431 +0,0 @@ -""" -Advanced Tests for Run Command - -Tests the most problematic aspects of the run command: -- Process lifecycle during model execution -- Memory management with model loading/unloading -- Streaming interruption handling -- Error conditions and recovery -""" -import pytest -import subprocess -import signal -import time -import threading -from pathlib import Path - - -@pytest.mark.timeout(120) -@pytest.mark.usefixtures("temp_cache_dir") -class TestRunCommandProcessLifecycle: - """Test process management during model execution.""" - - def test_run_command_normal_completion(self, mlx_knife_process, process_monitor, mock_model_cache): - """Test run command completes normally and cleans up.""" - # Create a mock model (won't actually run, but tests process handling) - mock_model_cache("test-model", healthy=True) - - proc = mlx_knife_process(["run", "test-model", "Hello"]) - main_pid = proc.pid - - # Track child processes - children_before = process_monitor["get_process_tree"](main_pid) - - try: - # Wait for completion (will likely fail due to mock model, but should not hang) - return_code = proc.wait(timeout=30) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Run command hung during execution") - - # Should complete (success or failure, but not hang) - assert return_code is not None, "Run command did not complete" - - # Verify child process cleanup - assert process_monitor["wait_for_cleanup"](main_pid, timeout=10) - - for child in children_before: - assert not child.is_running(), f"Run command left zombie process: PID {child.pid}" - - def test_run_command_sigint_during_execution(self, mlx_knife_process, process_monitor, mock_model_cache): - """Test interruption during model execution.""" - mock_model_cache("test-model", healthy=True) - - proc = mlx_knife_process(["run", "test-model", "This is a longer prompt that might take time"]) - main_pid = proc.pid - - # Give it time to start - time.sleep(2) - - children_before = process_monitor["get_process_tree"](main_pid) - - # Send interrupt - proc.send_signal(signal.SIGINT) - - try: - return_code = proc.wait(timeout=20) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Run command did not respond to SIGINT") - - # Should exit on interrupt - assert return_code is not None - assert return_code != 0 # Should not exit normally - - # Clean up child processes - assert process_monitor["wait_for_cleanup"](main_pid, timeout=10) - - for child in children_before: - assert not child.is_running(), f"Run child process survived SIGINT: PID {child.pid}" - - def test_run_command_sigterm_handling(self, mlx_knife_process, process_monitor, mock_model_cache): - """Test SIGTERM during model execution.""" - mock_model_cache("test-model", healthy=True) - - proc = mlx_knife_process(["run", "test-model", "Test prompt"]) - main_pid = proc.pid - - time.sleep(2) - children_before = process_monitor["get_process_tree"](main_pid) - - # Send SIGTERM - proc.send_signal(signal.SIGTERM) - - try: - return_code = proc.wait(timeout=20) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Run command did not respond to SIGTERM") - - assert return_code is not None - assert return_code != 0 - - # Cleanup verification - assert process_monitor["wait_for_cleanup"](main_pid, timeout=10) - - for child in children_before: - assert not child.is_running(), f"Run child survived SIGTERM: PID {child.pid}" - - def test_run_command_model_loading_failure(self, mlx_knife_process, process_monitor): - """Test process cleanup when model loading fails.""" - # Use nonexistent model to trigger loading failure - proc = mlx_knife_process(["run", "nonexistent-model-12345", "Test prompt"]) - main_pid = proc.pid - - children_before = process_monitor["get_process_tree"](main_pid) - - try: - return_code = proc.wait(timeout=20) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Run command hung on model loading failure") - - # Should fail gracefully - assert return_code is not None - assert return_code != 0 # Should fail on missing model - - # Should not leave zombies even on failure - assert process_monitor["wait_for_cleanup"](main_pid, timeout=5) - - for child in children_before: - assert not child.is_running(), f"Process survived model loading failure: PID {child.pid}" - - -@pytest.mark.timeout(90) -@pytest.mark.usefixtures("temp_cache_dir") -class TestRunCommandMemoryManagement: - """Test memory management during run command execution.""" - - def test_run_command_memory_cleanup_after_completion(self, mlx_knife_process, mock_model_cache): - """Test memory is released after run command completes.""" - mock_model_cache("test-model", healthy=True) - - # Run command multiple times to test memory cleanup - for i in range(3): - proc = mlx_knife_process(["run", "test-model", f"Test prompt {i}"]) - - try: - return_code = proc.wait(timeout=25) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail(f"Run command {i} hung") - - # Should complete (may fail, but should not hang) - assert return_code is not None, f"Run command {i} did not complete" - - def test_run_command_memory_cleanup_on_interruption(self, mlx_knife_process, process_monitor, mock_model_cache): - """Test memory cleanup when run is interrupted.""" - mock_model_cache("test-model", healthy=True) - - proc = mlx_knife_process(["run", "test-model", "Longer test prompt for interruption"]) - main_pid = proc.pid - - # Let it start - time.sleep(3) - - # Interrupt - proc.send_signal(signal.SIGINT) - - try: - return_code = proc.wait(timeout=15) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Run command did not handle interruption") - - # Verify cleanup - assert return_code is not None - assert process_monitor["wait_for_cleanup"](main_pid, timeout=10) - - def test_run_command_handles_corrupted_model(self, mlx_knife_process, mock_model_cache): - """Test run command handles corrupted models gracefully.""" - # Create corrupted model - mock_model_cache("broken-model", healthy=False, corruption_type="truncated_safetensors") - - proc = mlx_knife_process(["run", "broken-model", "Test prompt"]) - - try: - return_code = proc.wait(timeout=20) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Run command hung on corrupted model") - - # Should fail gracefully on corrupted model - assert return_code is not None - assert return_code != 0 # Should fail - - -@pytest.mark.timeout(60) -@pytest.mark.usefixtures("temp_cache_dir") -class TestRunCommandStreamingAndOutput: - """Test streaming and output handling in run command.""" - - def test_run_command_streaming_interruption(self, mlx_knife_process): - """Test interruption during token streaming with real MLX model.""" - test_model = "Phi-3-mini-4k-instruct-4bit" - # Use prompt that would generate substantial output - test_prompt = "Explain machine learning in detail with examples." - - proc = mlx_knife_process(["run", test_model, test_prompt]) - - # Let streaming start, then interrupt - time.sleep(3) # Allow generation to begin - - # Send interrupt signal - proc.send_signal(signal.SIGINT) - - try: - stdout, stderr = proc.communicate(timeout=15) - # Should handle interruption gracefully - assert proc.returncode is not None, "Process should terminate after interrupt" - # Should have generated some output before interruption - assert len(stdout) > 0, "Should have some output before interruption" - except subprocess.TimeoutExpired: - proc.kill() - stdout, stderr = proc.communicate() - pytest.fail("Process didn't respond to interruption signal") - - def test_run_command_output_handling(self, mlx_knife_process, mock_model_cache): - """Test that run command handles output correctly.""" - mock_model_cache("test-model", healthy=True) - - proc = mlx_knife_process(["run", "test-model", "Hello"]) - - try: - stdout, stderr = proc.communicate(timeout=20) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Run command hung during output") - - # Should produce some output (even if error message) - total_output = len(stdout) + len(stderr) - assert total_output > 0, "Run command produced no output" - - def test_run_command_long_prompt_handling(self, mlx_knife_process, mock_model_cache): - """Test run command with very long prompts.""" - mock_model_cache("test-model", healthy=True) - - # Create long prompt - long_prompt = "This is a test prompt. " * 100 # ~2500 characters - - proc = mlx_knife_process(["run", "test-model", long_prompt]) - - try: - return_code = proc.wait(timeout=25) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Run command hung on long prompt") - - # Should handle long prompt without hanging - assert return_code is not None - - def test_run_command_special_characters(self, mlx_knife_process, mock_model_cache): - """Test run command handles special characters in prompts.""" - mock_model_cache("test-model", healthy=True) - - special_prompts = [ - "Hello 世界", # Unicode - "Test with \"quotes\" and 'apostrophes'", # Quotes - "Newlines\nand\ttabs", # Whitespace - "emoji 🚀 test", # Emoji - ] - - for prompt in special_prompts: - proc = mlx_knife_process(["run", "test-model", prompt]) - - try: - return_code = proc.wait(timeout=20) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail(f"Run command hung on special characters: {prompt[:20]}...") - - # Should handle special characters gracefully - assert return_code is not None - - -@pytest.mark.timeout(45) -@pytest.mark.usefixtures("temp_cache_dir") -class TestRunCommandErrorConditions: - """Test run command error handling.""" - - def test_run_command_insufficient_memory(self, mlx_knife_process, mock_model_cache): - """Test behavior when system might be low on memory.""" - mock_model_cache("large-model", healthy=True) - - # We can't actually simulate low memory, but we can test the process handles errors - proc = mlx_knife_process(["run", "large-model", "Test prompt"]) - - try: - return_code = proc.wait(timeout=25) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Run command hung during error condition") - - # Should complete (success or failure) - assert return_code is not None - - def test_run_command_missing_dependencies(self, mlx_knife_process): - """Test run command when model dependencies might be missing.""" - # Try to run with invalid model to test error handling - proc = mlx_knife_process(["run", "invalid/missing-model", "Test"]) - - try: - return_code = proc.wait(timeout=15) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Run command hung on missing dependencies") - - # Should fail gracefully - assert return_code is not None - assert return_code != 0 - - def test_run_command_multiple_concurrent_executions(self, mlx_knife_process, mock_model_cache): - """Test multiple concurrent run commands don't interfere.""" - mock_model_cache("test-model", healthy=True) - - processes = [] - - # Start multiple run commands - for i in range(3): - proc = mlx_knife_process(["run", "test-model", f"Concurrent test {i}"]) - processes.append(proc) - - # Wait for all to complete - for i, proc in enumerate(processes): - try: - return_code = proc.wait(timeout=30) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail(f"Concurrent run command {i} hung") - - # Each should complete independently - assert return_code is not None, f"Concurrent run {i} did not complete" - - -@pytest.mark.timeout(60) -@pytest.mark.usefixtures("temp_cache_dir") -class TestRunCommandContextAwareLimits: - """Test context-aware token limits in Issues #15 and #16 resolution.""" - - def test_context_length_extraction_from_real_model(self, mlx_knife_process, mock_model_cache): - """Test that context length is correctly extracted from real model configs.""" - # Create a mock model with realistic config.json - model_path = mock_model_cache("test-model", healthy=True) - - # Add custom config.json with specific context length - config_content = { - "max_position_embeddings": 4096, - "hidden_size": 768, - "num_attention_heads": 12 - } - import json - (model_path / "config.json").write_text(json.dumps(config_content)) - - # Test that the model context length is accessible - # This is an indirect test - we test that the run command uses model-aware limits - # by checking that it doesn't hang with realistic models - proc = mlx_knife_process([ - "run", "test-model", "Test prompt", - "--max-tokens", "8000", # Request more than typical model context - "--verbose" - ]) - - try: - # Should complete within timeout (won't actually generate due to mock) - return_code = proc.wait(timeout=30) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Run command hung with high max-tokens") - - # Should complete (may fail due to mock model, but shouldn't hang) - assert return_code is not None, "Run command did not complete with context-aware limits" - - def test_server_vs_interactive_token_policies(self, mock_model_cache): - """Test that server mode uses DoS protection while interactive mode uses full context.""" - # This test validates the architectural decision: - # - Server mode: context_length / 2 (DoS protection) - # - Interactive mode: full context_length - - from mlx_knife.mlx_runner import MLXRunner, get_model_context_length - import tempfile - import json - import os - - # Create a temporary model directory with config - with tempfile.TemporaryDirectory() as temp_dir: - config_path = os.path.join(temp_dir, "config.json") - config = {"max_position_embeddings": 4096} - - with open(config_path, 'w') as f: - json.dump(config, f) - - # Test context length extraction - context_length = get_model_context_length(temp_dir) - assert context_length == 4096, "Context length extraction failed" - - # Test MLXRunner effective token calculation - runner = MLXRunner(temp_dir, verbose=False) - runner._context_length = 4096 - - # Interactive mode should use full context - interactive_tokens = runner.get_effective_max_tokens(8000, interactive=True) - assert interactive_tokens == 4096, f"Interactive mode should use full context: {interactive_tokens}" - - # Server mode should use half context (DoS protection) - server_tokens = runner.get_effective_max_tokens(8000, interactive=False) - assert server_tokens == 2048, f"Server mode should use half context: {server_tokens}" - - # User requests smaller than limits should be honored - small_interactive = runner.get_effective_max_tokens(1000, interactive=True) - assert small_interactive == 1000, "Small requests should be honored in interactive mode" - - small_server = runner.get_effective_max_tokens(1000, interactive=False) - assert small_server == 1000, "Small requests should be honored in server mode" - - # Test None behavior (new CLI default=None logic) - # Interactive mode with None should use full context - none_interactive = runner.get_effective_max_tokens(None, interactive=True) - assert none_interactive == 4096, "None in interactive mode should use full context" - - # Server mode with None should use server limit - none_server = runner.get_effective_max_tokens(None, interactive=False) - assert none_server == 2048, "None in server mode should use server limit (context/2)" \ No newline at end of file diff --git a/tests/integration/test_server_functionality.py b/tests/integration/test_server_functionality.py deleted file mode 100644 index 79c09f0..0000000 --- a/tests/integration/test_server_functionality.py +++ /dev/null @@ -1,555 +0,0 @@ -""" -High Priority Tests: Server Functionality - -Tests for the OpenAI-compatible API server: -- Server startup and shutdown -- Process lifecycle during server operations -- API endpoint availability -- Request handling and response format -- Server interruption and cleanup -""" -import pytest -import subprocess -import time -import requests -import signal -import json -from pathlib import Path - - -@pytest.mark.timeout(60) -class TestServerLifecycle: - """Test server startup, operation, and shutdown.""" - - def test_server_startup_shutdown(self, mlx_knife_process, process_monitor): - """Test server starts and shuts down cleanly.""" - # Start server - proc = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8001"]) - main_pid = proc.pid - - # Give server time to start - time.sleep(3) - - # Check if server is responsive (basic health check) - try: - response = requests.get("http://127.0.0.1:8001/health", timeout=5) - server_started = response.status_code == 200 - except requests.exceptions.RequestException: - # Server might not have health endpoint, that's OK - server_started = proc.poll() is None # Process still running - - # Track child processes - children_before = process_monitor["get_process_tree"](main_pid) - - # Shutdown server - proc.send_signal(signal.SIGINT) - - try: - return_code = proc.wait(timeout=15) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Server did not shutdown within timeout") - - # Verify clean shutdown - assert return_code is not None, "Server process did not terminate" - - # Verify all child processes cleaned up - assert process_monitor["wait_for_cleanup"](main_pid, timeout=10) - - for child in children_before: - assert not child.is_running(), f"Server child process survived: PID {child.pid}" - - def test_server_sigterm_handling(self, mlx_knife_process, process_monitor): - """Test server responds to SIGTERM gracefully.""" - proc = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8002"]) - main_pid = proc.pid - - time.sleep(3) - children_before = process_monitor["get_process_tree"](main_pid) - - # Send SIGTERM - proc.send_signal(signal.SIGTERM) - - try: - return_code = proc.wait(timeout=15) - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail("Server did not respond to SIGTERM") - - # Should exit gracefully - assert return_code is not None - - # Clean up child processes - assert process_monitor["wait_for_cleanup"](main_pid, timeout=10) - - for child in children_before: - assert not child.is_running(), f"Server child survived SIGTERM: PID {child.pid}" - - def test_server_sigkill_cleanup(self, mlx_knife_process, process_monitor): - """Test cleanup after SIGKILL.""" - proc = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8003"]) - main_pid = proc.pid - - time.sleep(3) - children_before = process_monitor["get_process_tree"](main_pid) - - # SIGKILL should kill immediately - proc.send_signal(signal.SIGKILL) - - try: - return_code = proc.wait(timeout=10) - except subprocess.TimeoutExpired: - pytest.fail("Process did not die from SIGKILL") - - assert return_code == -signal.SIGKILL - - # Child processes should be cleaned up by OS - assert process_monitor["wait_for_cleanup"](main_pid, timeout=10) - - def test_server_port_binding_conflicts(self, mlx_knife_process): - """Test server handles port conflicts gracefully.""" - # Start first server - proc1 = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8004"]) - time.sleep(3) - - # Try to start second server on same port - proc2 = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8004"]) - - try: - # Second server should fail quickly - return_code2 = proc2.wait(timeout=10) - assert return_code2 != 0, "Second server should fail on port conflict" - except subprocess.TimeoutExpired: - proc2.kill() - pytest.fail("Second server did not fail quickly on port conflict") - finally: - # Clean up first server - if proc1.poll() is None: - proc1.send_signal(signal.SIGINT) - proc1.wait(timeout=10) - - def test_server_invalid_arguments(self, mlx_knife_process): - """Test server handles invalid arguments gracefully.""" - invalid_configs = [ - ["server", "--port", "99999"], # Invalid port - ["server", "--host", "invalid-host"], # Invalid host - ["server", "--max-tokens", "-1"], # Invalid max tokens - ] - - for config in invalid_configs: - proc = mlx_knife_process(config) - try: - return_code = proc.wait(timeout=10) - # Should fail gracefully, not hang - assert return_code is not None, f"Server hung on invalid config: {config}" - assert return_code != 0, f"Server should fail on invalid config: {config}" - except subprocess.TimeoutExpired: - proc.kill() - pytest.fail(f"Server hung on invalid config: {config}") - - -@pytest.mark.timeout(90) -class TestServerAPI: - """Test server API functionality.""" - - def test_server_health_endpoint(self, mlx_knife_process): - """Test server health/status endpoint if available.""" - proc = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8005"]) - - # Wait for server to start - time.sleep(4) - - try: - # Try common health endpoints - health_endpoints = [ - "http://127.0.0.1:8005/health", - "http://127.0.0.1:8005/v1/models", - "http://127.0.0.1:8005/", - ] - - server_responsive = False - for endpoint in health_endpoints: - try: - response = requests.get(endpoint, timeout=5) - if response.status_code in [200, 404]: # 404 is OK, means server is running - server_responsive = True - break - except requests.exceptions.RequestException: - continue - - # Server should be responsive to at least one endpoint - assert server_responsive, "Server not responsive to any health endpoints" - - finally: - # Clean up - if proc.poll() is None: - proc.send_signal(signal.SIGINT) - proc.wait(timeout=15) - - def test_server_openai_models_endpoint(self, mlx_knife_process): - """Test OpenAI-compatible /v1/models endpoint.""" - proc = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8006"]) - - time.sleep(4) - - try: - response = requests.get("http://127.0.0.1:8006/v1/models", timeout=10) - - # Should respond (may be empty list if no models) - assert response.status_code == 200, f"Models endpoint failed: {response.status_code}" - - # Should return valid JSON - try: - data = response.json() - assert isinstance(data, dict), "Models endpoint should return JSON object" - # OpenAI format typically has 'data' field - if 'data' in data: - assert isinstance(data['data'], list), "Models data should be a list" - except json.JSONDecodeError: - pytest.fail("Models endpoint returned invalid JSON") - - except requests.exceptions.RequestException as e: - pytest.fail(f"Failed to connect to models endpoint: {e}") - finally: - if proc.poll() is None: - proc.send_signal(signal.SIGINT) - proc.wait(timeout=15) - - def test_server_chat_completions_endpoint(self, mlx_knife_process): - """Test OpenAI-compatible /v1/chat/completions endpoint structure.""" - proc = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8007"]) - - time.sleep(4) - - try: - # Test with minimal valid request - payload = { - "model": "test-model", - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10 - } - - response = requests.post( - "http://127.0.0.1:8007/v1/chat/completions", - json=payload, - timeout=15 - ) - - # Should respond (may be error if no models, but shouldn't hang) - assert response.status_code is not None, "Chat completions endpoint hung" - - # Should return JSON response - try: - data = response.json() - assert isinstance(data, dict), "Chat completions should return JSON object" - - if response.status_code == 200: - # Valid response should have expected fields - assert 'choices' in data or 'error' in data - elif response.status_code == 400: - # Bad request should have error message - assert 'error' in data - - except json.JSONDecodeError: - pytest.fail("Chat completions returned invalid JSON") - - except requests.exceptions.RequestException as e: - pytest.fail(f"Failed to connect to chat completions endpoint: {e}") - finally: - if proc.poll() is None: - proc.send_signal(signal.SIGINT) - proc.wait(timeout=15) - - @pytest.mark.server - def test_issue_19_server_token_limits_regression(self, mlx_knife_process): - """ - Regression test for Issue #19: Server output truncation at ~1000 words. - - Tests that server respects --max-tokens parameter and doesn't truncate - responses prematurely due to hardcoded 2000 token default. - """ - # Test with low max-tokens (should truncate early) - proc_low = mlx_knife_process([ - "server", "--host", "127.0.0.1", "--port", "8008", - "--max-tokens", "100" # Very low limit - ]) - - time.sleep(4) - - try: - # Long-form prompt that should trigger Issue #19 behavior - # Based on real user scenario that exposed the original truncation bug - trilogy_prompt = """Here is the outline for a fantasy trilogy "EMBERS OF THE FORGOTTEN": - -**MAIN CHARACTERS:** -1. Kaelen Veyra - The Exiled Flame Herald (32, war poet, controls Soulfire) -2. Sylra D'Tharn - The Shadow Warrior (28, assassin, uses Emotionweave) -3. Lord Morvath - The Unforgotten King (45, tragic villain with Grief-Crown) - -**TRILOGY STRUCTURE:** -- Book I: "Embers of the Forgotten" - The flame that remembers -- Book II: "The Lovers' Crucible" - The fire that doesn't burn -- Book III: "The Fire That Binds" - The flame that connects - -**THEMES:** Love as power not weakness, memory as healing, emotions as connection - -**YOUR TASK:** Write the complete first chapter of Book I: "The Poet Who Burned" -- Focus on Kaelen's exile from Celestine after his beloved Lirien's execution -- Include his arrival at Veyra (Valley of Faces) with 30 lost masks -- Show his Soulfire powers and emotional depth -- Use poetic, mythic language with deep inner rhythm -- Target: 2000+ words with full character development and dialogue -- End with the mysterious mask whispering: "You were here - a thousand years ago" - -Write the complete chapter now.""" - - payload_long = { - "model": "test-model", - "messages": [{"role": "user", "content": trilogy_prompt}], - "stream": False, - "temperature": 0.7 - } - - response_low = requests.post( - "http://127.0.0.1:8008/v1/chat/completions", - json=payload_long, - timeout=30 - ) - - # Should respond with some content but truncated - if response_low.status_code == 200: - data_low = response_low.json() - if 'choices' in data_low and data_low['choices']: - content_low = data_low['choices'][0].get('message', {}).get('content', '') - # With max-tokens=100, content should be short - assert len(content_low.split()) < 200, f"Low token limit not enforced: {len(content_low.split())} words" - - except (requests.exceptions.RequestException, json.JSONDecodeError): - # If no model available, test structure is still validated - pass - finally: - if proc_low.poll() is None: - proc_low.send_signal(signal.SIGINT) - proc_low.wait(timeout=15) - - # Test with high max-tokens (should allow longer responses) - proc_high = mlx_knife_process([ - "server", "--host", "127.0.0.1", "--port", "8009", - "--max-tokens", "10000" # High limit - ]) - - time.sleep(4) - - try: - response_high = requests.post( - "http://127.0.0.1:8009/v1/chat/completions", - json=payload_long, - timeout=60 - ) - - # Should allow longer responses - if response_high.status_code == 200: - data_high = response_high.json() - if 'choices' in data_high and data_high['choices']: - content_high = data_high['choices'][0].get('message', {}).get('content', '') - # High token limit should allow more content (if model available) - # This test validates server respects the --max-tokens parameter - assert isinstance(content_high, str), "Response content should be string" - - except (requests.exceptions.RequestException, json.JSONDecodeError): - pass - finally: - if proc_high.poll() is None: - proc_high.send_signal(signal.SIGINT) - proc_high.wait(timeout=15) - - def test_server_startup_token_limit_messages(self, mlx_knife_process): - """Test that server startup shows correct token limit configuration.""" - # Test default (None) shows dynamic limits message - proc_default = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8010"]) - time.sleep(4) - - try: - # Stop server first to avoid blocking read - if proc_default.poll() is None: - proc_default.send_signal(signal.SIGINT) - proc_default.wait(timeout=15) - - # Now safely read stdout after server shutdown - stdout_data = proc_default.stdout.read() if proc_default.stdout else b"" - stdout_text = stdout_data.decode('utf-8', errors='ignore') - - # Should show dynamic limits message when no --max-tokens specified - if stdout_text.strip(): # Only check if we got output - assert "model-aware dynamic limits" in stdout_text, f"Expected dynamic limits message, got: {stdout_text}" - - except Exception: - # If no stdout capture available, test passes (infrastructure limitation) - pass - - # Test explicit --max-tokens shows numeric value - proc_explicit = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8011", "--max-tokens", "5000"]) - time.sleep(4) - - try: - # Stop server first to avoid blocking read - if proc_explicit.poll() is None: - proc_explicit.send_signal(signal.SIGINT) - proc_explicit.wait(timeout=15) - - # Now safely read stdout after server shutdown - stdout_data = proc_explicit.stdout.read() if proc_explicit.stdout else b"" - stdout_text = stdout_data.decode('utf-8', errors='ignore') - - # Should show explicit numeric value - if stdout_text.strip(): # Only check if we got output - assert "5000" in stdout_text, f"Expected '5000' in startup message, got: {stdout_text}" - - except Exception: - pass - - def test_server_streaming_endpoint(self, mlx_knife_process): - """Test streaming functionality if available.""" - proc = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8008"]) - - time.sleep(4) - - try: - # Test streaming request - payload = { - "model": "test-model", - "messages": [{"role": "user", "content": "Hi"}], - "max_tokens": 5, - "stream": True - } - - response = requests.post( - "http://127.0.0.1:8008/v1/chat/completions", - json=payload, - timeout=20, - stream=True - ) - - # Should respond to streaming request - assert response.status_code is not None, "Streaming endpoint hung" - - # Should handle streaming gracefully (may error if no model) - if response.status_code == 200: - # Should return SSE format or similar - assert 'text/plain' in response.headers.get('content-type', '') or \ - 'text/event-stream' in response.headers.get('content-type', '') or \ - 'application/json' in response.headers.get('content-type', '') - - except requests.exceptions.RequestException as e: - pytest.fail(f"Streaming endpoint connection failed: {e}") - finally: - if proc.poll() is None: - proc.send_signal(signal.SIGINT) - proc.wait(timeout=15) - - -@pytest.mark.timeout(45) -class TestServerResourceManagement: - """Test server resource management.""" - - def test_server_memory_cleanup_after_shutdown(self, mlx_knife_process): - """Test that server cleans up memory after shutdown.""" - # Start and stop server multiple times - for i in range(3): - proc = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", f"800{9+i}"]) - - time.sleep(2) - - # Shutdown cleanly - proc.send_signal(signal.SIGINT) - return_code = proc.wait(timeout=15) - - assert return_code is not None, f"Server {i} did not shutdown" - - def test_server_handles_multiple_requests(self, mlx_knife_process): - """Test server can handle multiple concurrent requests without hanging.""" - proc = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8012"]) - - time.sleep(4) - - try: - # Send multiple requests concurrently - import threading - import queue - - results = queue.Queue() - - def make_request(endpoint): - try: - response = requests.get(f"http://127.0.0.1:8012{endpoint}", timeout=10) - results.put(("success", response.status_code)) - except Exception as e: - results.put(("error", str(e))) - - # Start multiple threads - threads = [] - endpoints = ["/v1/models", "/v1/models", "/v1/models"] - - for endpoint in endpoints: - thread = threading.Thread(target=make_request, args=(endpoint,)) - threads.append(thread) - thread.start() - - # Wait for all threads - for thread in threads: - thread.join(timeout=20) - assert not thread.is_alive(), "Request thread hung" - - # Check results - success_count = 0 - while not results.empty(): - result_type, result_value = results.get() - if result_type == "success": - success_count += 1 - - # At least some requests should succeed - assert success_count > 0, "No requests succeeded" - - finally: - if proc.poll() is None: - proc.send_signal(signal.SIGINT) - proc.wait(timeout=15) - - def test_server_request_interruption(self, mlx_knife_process): - """Test server handles request interruption cleanly.""" - proc = mlx_knife_process(["server", "--host", "127.0.0.1", "--port", "8013"]) - - time.sleep(4) - - try: - # Start a request and interrupt it - import threading - - def make_slow_request(): - try: - requests.get("http://127.0.0.1:8013/v1/models", timeout=2) - except: - pass # Expected to timeout/fail - - # Start request in background - request_thread = threading.Thread(target=make_slow_request) - request_thread.start() - - # Give request time to start - time.sleep(1) - - # Shutdown server while request is in progress - proc.send_signal(signal.SIGINT) - return_code = proc.wait(timeout=15) - - # Server should shutdown cleanly even with active requests - assert return_code is not None, "Server hung during request interruption" - - # Request thread should complete - request_thread.join(timeout=10) - assert not request_thread.is_alive(), "Request thread hung after server shutdown" - - finally: - if proc.poll() is None: - proc.kill() - proc.wait() \ No newline at end of file diff --git a/tests/unit/test_cache_utils.py b/tests/unit/test_cache_utils.py deleted file mode 100644 index 84641c6..0000000 --- a/tests/unit/test_cache_utils.py +++ /dev/null @@ -1,902 +0,0 @@ -""" -Unit tests for cache_utils.py module. - -Tests the core model management functions: -- Model discovery and metadata extraction -- Health checking logic -- Cache operations -""" -import pytest -import tempfile -import shutil -import json -from pathlib import Path -from unittest.mock import patch, MagicMock, call - -# Import the module under test -from mlx_knife.cache_utils import ( - expand_model_name, - hf_to_cache_dir, - cache_dir_to_hf, - is_model_healthy, - detect_framework, - list_models, - find_matching_models, - resolve_single_model -) - - -class TestModelNameExpansion: - """Test model name expansion logic.""" - - def test_expand_short_names(self): - """Test expansion of common short model names.""" - test_cases = [ - ("Phi-3-mini", "mlx-community/Phi-3-mini-4k-instruct-4bit"), - ("Mistral-7B", "mlx-community/Mistral-7B-Instruct-v0.3-4bit"), - ("Llama-3-8B", "mlx-community/Meta-Llama-3-8B-Instruct-4bit"), - ] - - for short_name, expected in test_cases: - try: - result = expand_model_name(short_name) - # Should either expand correctly or return the original name - assert isinstance(result, str) - assert len(result) > 0 - except Exception as e: - pytest.fail(f"expand_model_name failed for {short_name}: {e}") - - def test_expand_full_names(self): - """Test that full model names are returned unchanged.""" - full_names = [ - "mlx-community/Phi-3-mini-4k-instruct-4bit", - "microsoft/Phi-3-mini-4k-instruct", - "meta-llama/Llama-2-7b-chat-hf" - ] - - for full_name in full_names: - try: - result = expand_model_name(full_name) - # Should return the name as-is or expand it - assert isinstance(result, str) - assert len(result) > 0 - except Exception as e: - pytest.fail(f"expand_model_name failed for {full_name}: {e}") - - def test_expand_invalid_names(self): - """Test handling of invalid or nonsense model names.""" - invalid_names = [ - "definitely-not-a-model-12345", - "", - " ", - "invalid/model/with/too/many/slashes" - ] - - for invalid_name in invalid_names: - try: - result = expand_model_name(invalid_name) - # Should handle gracefully - either return input or raise appropriate error - if result is not None: - assert isinstance(result, str) - except Exception: - # It's OK to raise exceptions for invalid names - pass - - -class TestCacheDirectoryConversion: - """Test cache directory name conversion functions.""" - - def test_hf_to_cache_dir(self): - """Test HuggingFace model name to cache directory conversion.""" - test_cases = [ - ("microsoft/Phi-3-mini-4k-instruct", "models--microsoft--Phi-3-mini-4k-instruct"), - ("meta-llama/Llama-2-7b", "models--meta-llama--Llama-2-7b"), - ("simple-model", "models--simple-model"), - ] - - for hf_name, expected_cache_dir in test_cases: - try: - result = hf_to_cache_dir(hf_name) - assert isinstance(result, str) - # Should follow HF cache naming convention - assert result.startswith("models--") - assert "--" in result - except Exception as e: - pytest.fail(f"hf_to_cache_dir failed for {hf_name}: {e}") - - def test_cache_dir_to_hf(self): - """Test cache directory to HuggingFace model name conversion.""" - test_cases = [ - ("models--microsoft--Phi-3-mini-4k-instruct", "microsoft/Phi-3-mini-4k-instruct"), - ("models--meta-llama--Llama-2-7b", "meta-llama/Llama-2-7b"), - ("models--simple-model", "simple-model"), - ] - - for cache_dir, expected_hf_name in test_cases: - try: - result = cache_dir_to_hf(cache_dir) - assert isinstance(result, str) - # Should reverse the cache directory format - assert "/" in result or len(result.split("--")) == 1 - except Exception as e: - pytest.fail(f"cache_dir_to_hf failed for {cache_dir}: {e}") - - def test_round_trip_conversion(self): - """Test that conversion functions are inverses.""" - test_names = [ - "microsoft/Phi-3-mini-4k-instruct", - "simple-model", - "org/model-name-with-dashes" - ] - - for original_name in test_names: - try: - cache_dir = hf_to_cache_dir(original_name) - recovered_name = cache_dir_to_hf(cache_dir) - - assert recovered_name == original_name, \ - f"Round trip failed: {original_name} -> {cache_dir} -> {recovered_name}" - except Exception as e: - pytest.fail(f"Round trip conversion failed for {original_name}: {e}") - - -class TestModelHealthCheck: - """Test model health checking logic.""" - - def test_healthy_model_structure(self, temp_cache_dir): - """Test health check on properly structured model.""" - # Create a healthy model structure - model_dir = temp_cache_dir / "models--test--model" / "snapshots" / "main" - model_dir.mkdir(parents=True) - - # Create required files - (model_dir / "config.json").write_text('{"model_type": "test", "architectures": ["TestModel"]}') - (model_dir / "tokenizer.json").write_text('{"version": "1.0", "tokenizer": {}}') - (model_dir / "model.safetensors").write_bytes(b"fake_model_weights" * 100) - - try: - is_healthy = is_model_healthy(str(model_dir)) - # Should be True for healthy model - assert isinstance(is_healthy, bool) - except Exception as e: - pytest.fail(f"Health check failed on healthy model: {e}") - - def test_missing_config_detection(self, temp_cache_dir): - """Test detection of missing config.json.""" - model_dir = temp_cache_dir / "models--test--model" / "snapshots" / "main" - model_dir.mkdir(parents=True) - - # Missing config.json - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - (model_dir / "model.safetensors").write_bytes(b"fake_weights") - - try: - is_healthy = is_model_healthy(str(model_dir)) - # Should detect missing config - assert isinstance(is_healthy, bool) - # Likely should be False, but depends on implementation - except Exception as e: - # It's OK to raise exception for missing config - pass - - def test_missing_tokenizer_detection(self, temp_cache_dir): - """Test detection of missing tokenizer.json.""" - model_dir = temp_cache_dir / "models--test--model" / "snapshots" / "main" - model_dir.mkdir(parents=True) - - # Missing tokenizer.json - (model_dir / "config.json").write_text('{"model_type": "test"}') - (model_dir / "model.safetensors").write_bytes(b"fake_weights") - - try: - is_healthy = is_model_healthy(str(model_dir)) - assert isinstance(is_healthy, bool) - except Exception as e: - # OK to raise exception for missing tokenizer - pass - - def test_missing_model_weights(self, temp_cache_dir): - """Test detection of missing model weights.""" - model_dir = temp_cache_dir / "models--test--model" / "snapshots" / "main" - model_dir.mkdir(parents=True) - - # Missing model files - (model_dir / "config.json").write_text('{"model_type": "test"}') - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - # No .safetensors files - - try: - is_healthy = is_model_healthy(str(model_dir)) - assert isinstance(is_healthy, bool) - except Exception as e: - # OK to raise exception for missing weights - pass - - def test_lfs_pointer_detection(self, temp_cache_dir): - """Test detection of LFS pointer files.""" - model_dir = temp_cache_dir / "models--test--model" / "snapshots" / "main" - model_dir.mkdir(parents=True) - - (model_dir / "config.json").write_text('{"model_type": "test"}') - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - - # Create LFS pointer file instead of actual weights - lfs_content = ( - "version https://git-lfs.github.com/spec/v1\n" - "oid sha256:abc123def456\n" - "size 1000000000\n" - ) - (model_dir / "model.safetensors").write_text(lfs_content) - - try: - is_healthy = is_model_healthy(str(model_dir)) - # Should detect LFS pointer as unhealthy - assert isinstance(is_healthy, bool) - except Exception as e: - # OK to raise exception for LFS pointers - pass - - def test_nonexistent_directory(self): - """Test health check on nonexistent directory.""" - nonexistent_path = "/this/path/definitely/does/not/exist" - - try: - is_healthy = is_model_healthy(nonexistent_path) - # Should handle gracefully - assert isinstance(is_healthy, bool) - assert is_healthy is False # Nonexistent should be unhealthy - except Exception: - # OK to raise exception for nonexistent path - pass - - -class TestFrameworkDetection: - """Test model framework detection logic.""" - - def test_mlx_model_detection(self, temp_cache_dir): - """Test detection of MLX-compatible models.""" - model_dir = temp_cache_dir / "models--mlx-community--test-model" / "snapshots" / "main" - model_dir.mkdir(parents=True) - - # Create MLX model config - mlx_config = { - "model_type": "llama", - "architectures": ["LlamaForCausalLM"], - "quantization": {"group_size": 64, "bits": 4} # MLX quantization - } - (model_dir / "config.json").write_text(json.dumps(mlx_config)) - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - (model_dir / "model.safetensors").write_bytes(b"mlx_weights") - - try: - from pathlib import Path - framework = detect_framework(Path(str(model_dir)), "mlx-community/test-model") - assert isinstance(framework, str) - # Should detect as MLX or compatible - except Exception as e: - pytest.fail(f"Framework detection failed on MLX model: {e}") - - def test_pytorch_model_detection(self, temp_cache_dir): - """Test detection of PyTorch models.""" - model_dir = temp_cache_dir / "models--pytorch--test-model" / "snapshots" / "main" - model_dir.mkdir(parents=True) - - # Create PyTorch model config - pytorch_config = { - "model_type": "bert", - "architectures": ["BertForSequenceClassification"], - "torch_dtype": "float32" - } - (model_dir / "config.json").write_text(json.dumps(pytorch_config)) - (model_dir / "tokenizer.json").write_text('{"version": "1.0"}') - (model_dir / "pytorch_model.bin").write_bytes(b"pytorch_weights") - - try: - from pathlib import Path - framework = detect_framework(Path(str(model_dir)), "pytorch/test-model") - assert isinstance(framework, str) - except Exception as e: - pytest.fail(f"Framework detection failed on PyTorch model: {e}") - - -class TestModelListing: - """Test model listing functionality.""" - - @patch('mlx_knife.cache_utils.MODEL_CACHE') - def test_list_models_empty_cache(self, mock_cache, temp_cache_dir): - """Test model listing in empty cache.""" - mock_cache.__str__ = lambda: str(temp_cache_dir) - mock_cache.exists.return_value = True - mock_cache.glob.return_value = [] - - try: - # list_models prints to stdout, so we test it doesn't crash - list_models(verbose=False) - except Exception as e: - pytest.fail(f"Model listing failed on empty cache: {e}") - - def test_list_models_real_empty_cache(self, temp_cache_dir): - """Test Issue #21: list_models with real empty HF_HOME directory.""" - import os - from mlx_knife.cache_utils import list_models - - # Create empty cache directory - empty_cache = temp_cache_dir / "empty_hf_cache" - empty_cache.mkdir() - - # Set HF_HOME to empty directory and test - original_hf_home = os.environ.get('HF_HOME') - try: - os.environ['HF_HOME'] = str(empty_cache) - # Should not crash and should print helpful message - list_models() - except FileNotFoundError as e: - pytest.fail(f"Issue #21 regression: list_models crashed with empty cache: {e}") - finally: - if original_hf_home is not None: - os.environ['HF_HOME'] = original_hf_home - elif 'HF_HOME' in os.environ: - del os.environ['HF_HOME'] - - @patch('mlx_knife.cache_utils.MODEL_CACHE') - def test_list_models_basic_call(self, mock_cache, temp_cache_dir): - """Test basic model listing call.""" - mock_cache.__str__ = lambda: str(temp_cache_dir) - mock_cache.exists.return_value = True - mock_cache.glob.return_value = [] - - try: - # Test various parameter combinations - list_models(show_all=True) - list_models(framework_filter="MLX") - list_models(show_health=True) - except Exception as e: - pytest.fail(f"Model listing with parameters failed: {e}") - - -class TestModelRemoval: - """Test rm_model functionality (Issue #23).""" - - def setup_method(self): - """Setup mock cache structure for each test.""" - self.test_model_name = "microsoft/DialoGPT-small" - self.test_hash = "49c537161a457d5256512f9d2d38a87d81ae0f0e" - self.test_hash_short = "49c53716" - - @patch('mlx_knife.cache_utils.MODEL_CACHE') - @patch('mlx_knife.cache_utils.resolve_single_model') - @patch('mlx_knife.cache_utils.shutil.rmtree') - @patch('builtins.input', return_value='y') - def test_rm_model_fixed_behavior_issue23(self, mock_input, mock_rmtree, mock_resolve, mock_cache, temp_cache_dir): - """Test fixed rm behavior - should delete model AND locks (Issue #23 resolved). - - Setup mocked directory structure as documented in CLAUDE.md: - hub/ - ├── .locks/models--/ # Per-model lock files - └── models--/ # Model data directory - ├── blobs/ # Deduplicated file storage - ├── refs/main # Points to current commit hash - └── snapshots// # Specific version - """ - from mlx_knife.cache_utils import rm_model - - # Create real temp directories that mirror HF cache structure - # After fix: MODEL_CACHE points to hub/, locks are at hub/.locks/ - hub_dir = temp_cache_dir / "hub" - model_dir = hub_dir / "models--microsoft--DialoGPT-small" - snapshots_dir = model_dir / "snapshots" - hash_dir = snapshots_dir / self.test_hash_short - refs_dir = model_dir / "refs" - blobs_dir = model_dir / "blobs" - locks_dir = hub_dir / ".locks" / "models--microsoft--DialoGPT-small" - - # Create the directory structure (but don't populate with real files) - hash_dir.mkdir(parents=True) - refs_dir.mkdir(parents=True) - blobs_dir.mkdir(parents=True) - locks_dir.mkdir(parents=True) - - # Create refs/main file pointing to hash - (refs_dir / "main").write_text(self.test_hash_short) - - # Create some mock lock files - (locks_dir / "file1.lock").touch() - (locks_dir / "file2.lock").touch() - - # Mock resolve_single_model to return our temp structure - mock_resolve.return_value = (model_dir, self.test_model_name, self.test_hash_short) - - # Mock MODEL_CACHE to point to hub directory (after fix: locks are at MODEL_CACHE/.locks/) - import mlx_knife.cache_utils - mlx_knife.cache_utils.MODEL_CACHE = hub_dir - - # Verify our test structure exists - assert model_dir.exists() - assert hash_dir.exists() - assert (refs_dir / "main").exists() - assert locks_dir.exists() - assert len(list(locks_dir.iterdir())) == 2 - - # Test current rm behavior - this should show Issue #23 - rm_model(f"{self.test_model_name}@{self.test_hash_short}") - - # Verify what was actually deleted - # Fixed behavior: should delete model directory AND locks directory - assert mock_rmtree.call_count == 2 - - # Verify both calls: model directory and locks directory - calls = [call[0][0] for call in mock_rmtree.call_args_list] - model_call = next((call for call in calls if "models--microsoft--DialoGPT-small" in str(call) and ".locks" not in str(call)), None) - locks_call = next((call for call in calls if ".locks" in str(call)), None) - - assert model_call is not None, "Should delete model directory" - assert locks_call is not None, "Should delete locks directory" - - @patch('mlx_knife.cache_utils.MODEL_CACHE') - @patch('mlx_knife.cache_utils.resolve_single_model') - @patch('mlx_knife.cache_utils.shutil.rmtree') - def test_rm_model_force_parameter(self, mock_rmtree, mock_resolve, mock_cache, temp_cache_dir): - """Test rm_model with force=True skips all confirmations.""" - from mlx_knife.cache_utils import rm_model - - # Create same temp structure as previous test (updated for fix) - hub_dir = temp_cache_dir / "hub" - model_dir = hub_dir / "models--microsoft--DialoGPT-small" - snapshots_dir = model_dir / "snapshots" - hash_dir = snapshots_dir / self.test_hash_short - locks_dir = hub_dir / ".locks" / "models--microsoft--DialoGPT-small" - - # Create the directory structure - hash_dir.mkdir(parents=True) - locks_dir.mkdir(parents=True) - (locks_dir / "file1.lock").touch() - (locks_dir / "file2.lock").touch() - - # Mock resolve_single_model to return our temp structure - mock_resolve.return_value = (model_dir, self.test_model_name, self.test_hash_short) - - # Mock MODEL_CACHE to point to hub directory (after fix) - import mlx_knife.cache_utils - mlx_knife.cache_utils.MODEL_CACHE = hub_dir - - # Test with force=True - should NOT call input() at all - with patch('builtins.input') as mock_input: - rm_model(f"{self.test_model_name}@{self.test_hash_short}", force=True) - - # Verify input() was never called (no prompts with force=True) - mock_input.assert_not_called() - - # Verify both model and locks were deleted - assert mock_rmtree.call_count == 2 - calls = [call[0][0] for call in mock_rmtree.call_args_list] - model_call = next((call for call in calls if "models--microsoft--DialoGPT-small" in str(call) and ".locks" not in str(call)), None) - locks_call = next((call for call in calls if ".locks" in str(call)), None) - - assert model_call is not None, "Should delete model directory with force=True" - assert locks_call is not None, "Should delete locks directory with force=True" - - @patch('mlx_knife.cache_utils.MODEL_CACHE') - @patch('mlx_knife.cache_utils.resolve_single_model') - @patch('mlx_knife.cache_utils.shutil.rmtree') - def test_rm_model_force_vs_interactive(self, mock_rmtree, mock_resolve, mock_cache, temp_cache_dir): - """Test that force=True behaves differently than interactive mode.""" - from mlx_knife.cache_utils import rm_model - - # Create temp structure (updated for fix) - hub_dir = temp_cache_dir / "hub" - model_dir = hub_dir / "models--test--model" - snapshots_dir = model_dir / "snapshots" - hash_dir = snapshots_dir / "abc12345" - locks_dir = hub_dir / ".locks" / "models--test--model" - - hash_dir.mkdir(parents=True) - locks_dir.mkdir(parents=True) - (locks_dir / "test.lock").touch() - - mock_resolve.return_value = (model_dir, "test/model", None) - # Mock MODEL_CACHE to point to hub directory (after fix) - import mlx_knife.cache_utils - mlx_knife.cache_utils.MODEL_CACHE = hub_dir - - # Test 1: Interactive mode - user says no - mock_rmtree.reset_mock() - with patch('builtins.input', return_value='n'): - rm_model("test/model", force=False) - # Should NOT delete anything when user says no - mock_rmtree.assert_not_called() - - # Test 2: Force mode - no prompts, just delete - mock_rmtree.reset_mock() - with patch('builtins.input') as mock_input: - rm_model("test/model", force=True) - # Should NOT prompt user - mock_input.assert_not_called() - # Should delete both model and locks - assert mock_rmtree.call_count == 2 - - - @patch('mlx_knife.cache_utils.resolve_single_model') - def test_rm_model_not_found(self, mock_resolve): - """Test rm behavior when model is not found.""" - from mlx_knife.cache_utils import rm_model - - # Setup resolve to return None (not found) - mock_resolve.return_value = (None, None, None) - - # Should return early without error - result = rm_model("nonexistent/model@hash") - assert result is None - - -class TestPartialNameFiltering: - """Test partial name filtering for list command (Issue 1).""" - - def test_find_matching_models_function(self): - """Test the find_matching_models helper function.""" - with patch('mlx_knife.cache_utils.MODEL_CACHE') as mock_cache: - # Mock some model directories - mock_models = [ - MagicMock(name="models--mlx-community--Phi-3-mini"), - MagicMock(name="models--mlx-community--Phi-3-medium"), - MagicMock(name="models--other--Llama-3-8B"), - ] - - for i, mock_model in enumerate(mock_models): - mock_model.name = f"models--{'mlx-community' if i < 2 else 'other'}--{'Phi-3-mini' if i == 0 else 'Phi-3-medium' if i == 1 else 'Llama-3-8B'}" - - mock_cache.iterdir.return_value = mock_models - - # Test finding Phi-3 models - matches = find_matching_models("Phi-3") - assert len(matches) == 2 - - # Test finding non-existent model - matches = find_matching_models("nonexistent") - assert len(matches) == 0 - - def test_partial_matching_basic_functionality(self): - """Test basic partial matching logic without complex mocking.""" - # Simple functional test of the helper functions - try: - # These functions exist and can be called - assert callable(find_matching_models) - # Function handles empty input gracefully - matches = find_matching_models("") - assert isinstance(matches, list) - except Exception as e: - pytest.fail(f"Basic functionality test failed: {e}") - - -class TestSingleModelFuzzyMatching: - """Test fuzzy matching for single-model commands (Issue 2).""" - - def test_resolve_single_model_function_exists(self): - """Test that resolve_single_model function exists and is callable.""" - try: - assert callable(resolve_single_model) - # Function handles invalid input gracefully - result = resolve_single_model("definitely-nonexistent-model-12345") - assert isinstance(result, tuple) - assert len(result) == 3 - except Exception as e: - pytest.fail(f"Function existence test failed: {e}") - - @patch('mlx_knife.cache_utils.get_model_path') - @patch('mlx_knife.cache_utils.find_matching_models') - def test_resolve_single_model_ambiguous_fuzzy(self, mock_find, mock_get_path, capsys): - """Test ambiguous fuzzy match shows error.""" - # Mock exact match fails, fuzzy finds multiple matches - mock_get_path.return_value = (None, None, None) - mock_find.return_value = [ - (MagicMock(), "model-1"), - (MagicMock(), "model-2") - ] - - result = resolve_single_model("partial") - assert result[0] is None # Should fail - - # Check that error message was printed - captured = capsys.readouterr() - assert "Multiple models match" in captured.out - assert "model-1" in captured.out - assert "model-2" in captured.out - - @patch('mlx_knife.cache_utils.get_model_path') - @patch('mlx_knife.cache_utils.find_matching_models') - def test_resolve_single_model_no_match(self, mock_find, mock_get_path, capsys): - """Test no match shows appropriate error.""" - # Mock both exact and fuzzy matching fail - mock_get_path.return_value = (None, None, None) - mock_find.return_value = [] - - result = resolve_single_model("nonexistent") - assert result[0] is None # Should fail - - # Check error message - captured = capsys.readouterr() - assert "No models found matching" in captured.out - - -class TestShowModelHealthConsistency: - """Test for Issue #7 - Health check inconsistency in show command with fuzzy model names.""" - - @patch('mlx_knife.cache_utils.resolve_single_model') - @patch('mlx_knife.cache_utils.is_model_healthy') - @patch('mlx_knife.cache_utils.get_model_size') - @patch('mlx_knife.cache_utils.get_model_modified') - @patch('mlx_knife.cache_utils.detect_framework') - @patch('builtins.print') - def test_show_model_health_consistency_fuzzy_vs_full_name(self, mock_print, mock_framework, - mock_modified, mock_size, mock_healthy, - mock_resolve, temp_cache_dir): - """Test that fuzzy and full model names show identical health status. - - This is a regression test for Issue #7 where: - - mlxk show Phi-3 showed "CORRUPTED" - - mlxk show mlx-community/Phi-3-mini-4k-instruct-4bit showed "OK" - for the same underlying model. - """ - # Setup mock model path - mock_model_path = temp_cache_dir / "models--mlx-community--Phi-3-mini-4k-instruct-4bit" / "snapshots" / "abc123" - mock_model_path.mkdir(parents=True) - - # Mock resolve_single_model to return consistent results - # Both fuzzy "Phi-3" and full name should resolve to same model_name - mock_resolve.return_value = ( - mock_model_path, - "mlx-community/Phi-3-mini-4k-instruct-4bit", # Resolved full name - "abc123" - ) - - # Mock other dependencies - mock_size.return_value = "4.2GB" - mock_modified.return_value = "2023-12-01 10:00:00" - mock_framework.return_value = "MLX" - - # Test both healthy and unhealthy scenarios - for health_status in [True, False]: - mock_healthy.return_value = health_status - mock_print.reset_mock() - - # Test fuzzy name - from mlx_knife.cache_utils import show_model - show_model("Phi-3") # Fuzzy name - fuzzy_calls = [str(call) for call in mock_print.call_args_list] - - mock_print.reset_mock() - - # Test full name - show_model("mlx-community/Phi-3-mini-4k-instruct-4bit") # Full name - full_calls = [str(call) for call in mock_print.call_args_list] - - # Both should have identical health output - fuzzy_health_output = [call for call in fuzzy_calls if "Health:" in call] - full_health_output = [call for call in full_calls if "Health:" in call] - - assert len(fuzzy_health_output) == 1, f"Expected 1 health output for fuzzy name, got {len(fuzzy_health_output)}" - assert len(full_health_output) == 1, f"Expected 1 health output for full name, got {len(full_health_output)}" - assert fuzzy_health_output == full_health_output, f"Health status differs: fuzzy={fuzzy_health_output} vs full={full_health_output}" - - # Verify is_model_healthy was called with resolved model name (not original spec) - expected_calls = [call("mlx-community/Phi-3-mini-4k-instruct-4bit")] * 2 - assert mock_healthy.call_args_list == expected_calls, f"is_model_healthy should be called with resolved name, got {mock_healthy.call_args_list}" - - # Reset for next iteration - mock_healthy.reset_mock() - - - -class TestIssue6RepositoryNameValidation: - """Test for Issue #6 - Add repository name length validation for HuggingFace Hub.""" - - @patch('builtins.input', return_value='y') # Mock user input to avoid stdin issues - def test_pull_model_rejects_long_names(self, mock_input, capsys): - """Test that repository names >96 characters are rejected.""" - from mlx_knife.hf_download import pull_model - - # Create a name that exceeds 96 characters after expansion - # Use direct long name that doesn't get expanded but is >96 chars - long_model_name = "organization-name/very-long-model-name-that-definitely-exceeds-the-character-limit-for-repositories-on-hf-platform" - - result = pull_model(long_model_name) - - assert result is False - - captured = capsys.readouterr() - assert "Repository name exceeds HuggingFace Hub limit" in captured.out - assert "96 characters" in captured.out - assert "cannot exist on HuggingFace Hub" in captured.out - - -class TestIssue13HashBasedDisambiguation: - """Test for Issue #13 - Hash-based disambiguation for ambiguous model names.""" - - def test_hash_exists_in_local_cache_full_hash(self): - """Test hash_exists_in_local_cache returns full hash when exact match exists.""" - with patch('mlx_knife.cache_utils.MODEL_CACHE') as mock_cache: - mock_hash_dir = MagicMock() - mock_hash_dir.exists.return_value = True - - mock_snapshots_dir = MagicMock() - mock_snapshots_dir.exists.return_value = True - mock_snapshots_dir.__truediv__.return_value = mock_hash_dir - - mock_base_dir = MagicMock() - mock_base_dir.exists.return_value = True - mock_base_dir.__truediv__.return_value = mock_snapshots_dir - - mock_cache.__truediv__.return_value = mock_base_dir - - from mlx_knife.cache_utils import hash_exists_in_local_cache - - full_hash = "a5339a4131f135d0fdc6a5c8b5bbed2753bbe0f3" - result = hash_exists_in_local_cache("mlx-community/Phi-3-mini", full_hash) - assert result == full_hash - - def test_hash_exists_in_local_cache_none_no_model(self): - """Test hash_exists_in_local_cache returns None when model doesn't exist.""" - with patch('mlx_knife.cache_utils.MODEL_CACHE') as mock_cache: - mock_base_dir = MagicMock() - mock_base_dir.exists.return_value = False - mock_cache.__truediv__.return_value = mock_base_dir - - from mlx_knife.cache_utils import hash_exists_in_local_cache - - result = hash_exists_in_local_cache("nonexistent/model", "hash123") - assert result is None - - def test_hash_exists_in_local_cache_none_no_hash(self): - """Test hash_exists_in_local_cache returns None when hash doesn't exist.""" - with patch('mlx_knife.cache_utils.MODEL_CACHE') as mock_cache: - mock_hash_dir = MagicMock() - mock_hash_dir.exists.return_value = False - - mock_snapshots_dir = MagicMock() - mock_snapshots_dir.exists.return_value = True - mock_snapshots_dir.__truediv__.return_value = mock_hash_dir - mock_snapshots_dir.iterdir.return_value = [] # No snapshots - - mock_base_dir = MagicMock() - mock_base_dir.exists.return_value = True - mock_base_dir.__truediv__.return_value = mock_snapshots_dir - - mock_cache.__truediv__.return_value = mock_base_dir - - from mlx_knife.cache_utils import hash_exists_in_local_cache - - result = hash_exists_in_local_cache("mlx-community/Phi-3-mini", "nonexistent") - assert result is None - - def test_hash_exists_in_local_cache_short_hash_resolution(self): - """Test hash_exists_in_local_cache resolves short hashes locally.""" - with patch('mlx_knife.cache_utils.MODEL_CACHE') as mock_cache: - # Mock exact match fails - mock_hash_dir = MagicMock() - mock_hash_dir.exists.return_value = False - - # Mock snapshots directory with matching hash - mock_snapshot = MagicMock() - mock_snapshot.is_dir.return_value = True - mock_snapshot.name = "de2dfaf56839b7d0e834157d2401dee02726874d" - - mock_snapshots_dir = MagicMock() - mock_snapshots_dir.exists.return_value = True - mock_snapshots_dir.__truediv__.return_value = mock_hash_dir - mock_snapshots_dir.iterdir.return_value = [mock_snapshot] - - mock_base_dir = MagicMock() - mock_base_dir.exists.return_value = True - mock_base_dir.__truediv__.return_value = mock_snapshots_dir - - mock_cache.__truediv__.return_value = mock_base_dir - - from mlx_knife.cache_utils import hash_exists_in_local_cache - - result = hash_exists_in_local_cache("mlx-community/Llama-3.3-70B", "de2dfaf5") - assert result == "de2dfaf56839b7d0e834157d2401dee02726874d" - - @patch('mlx_knife.cache_utils.get_model_path') - @patch('mlx_knife.cache_utils.hash_exists_in_local_cache') - @patch('mlx_knife.cache_utils.find_matching_models') - @patch('mlx_knife.cache_utils.MODEL_CACHE') - def test_resolve_single_model_hash_disambiguation_success(self, mock_cache, mock_find, mock_hash_exists, mock_get_path): - """Test successful hash-based disambiguation when multiple models match.""" - # Mock find_matching_models to return multiple matches - mock_find.return_value = [ - (MagicMock(), "mlx-community/Llama-3.2-1B-Instruct-4bit"), - (MagicMock(), "mlx-community/Llama-3.3-70B-Instruct-4bit"), - ] - - # Mock hash_exists_in_local_cache to return full hash for second model only - def mock_hash_exists_side_effect(model_name, commit_hash): - if model_name == "mlx-community/Llama-3.3-70B-Instruct-4bit": - return "de2dfaf56839b7d0e834157d2401dee02726874d" - return None - mock_hash_exists.side_effect = mock_hash_exists_side_effect - - # Mock get_model_path to return success - mock_get_path.return_value = (MagicMock(), "mlx-community/Llama-3.3-70B-Instruct-4bit", "de2dfaf5") - - # Mock MODEL_CACHE behavior for exact match check - mock_base_dir = MagicMock() - mock_base_dir.exists.return_value = False - mock_cache.__truediv__.return_value = mock_base_dir - - from mlx_knife.cache_utils import resolve_single_model - - result = resolve_single_model("Llama@de2dfaf5") - - # Should successfully resolve to the second model - assert result[1] == "mlx-community/Llama-3.3-70B-Instruct-4bit" - assert result[2] == "de2dfaf5" - - # Verify hash_exists_in_local_cache was called for both models - assert mock_hash_exists.call_count == 2 - - # Verify get_model_path was called with the resolved spec (full hash) - mock_get_path.assert_called_once_with("mlx-community/Llama-3.3-70B-Instruct-4bit@de2dfaf56839b7d0e834157d2401dee02726874d") - - @patch('mlx_knife.cache_utils.hash_exists_in_local_cache') - @patch('mlx_knife.cache_utils.find_matching_models') - @patch('mlx_knife.cache_utils.MODEL_CACHE') - def test_resolve_single_model_hash_disambiguation_no_match(self, mock_cache, mock_find, mock_hash_exists, capsys): - """Test hash-based disambiguation when hash doesn't exist in any model.""" - # Mock find_matching_models to return multiple matches - mock_find.return_value = [ - (MagicMock(), "mlx-community/Llama-3.2-1B-Instruct-4bit"), - (MagicMock(), "mlx-community/Llama-3.3-70B-Instruct-4bit"), - ] - - # Mock hash_exists_in_local_cache to return None for all models - mock_hash_exists.return_value = None - - # Mock MODEL_CACHE behavior for exact match check - mock_base_dir = MagicMock() - mock_base_dir.exists.return_value = False - mock_cache.__truediv__.return_value = mock_base_dir - - from mlx_knife.cache_utils import resolve_single_model - - result = resolve_single_model("Llama@nonexistent") - - # Should return None tuple - assert result == (None, None, None) - - # Check error message was printed - captured = capsys.readouterr() - assert "Hash 'nonexistent' not found in any model matching 'Llama'" in captured.out - assert "Available models:" in captured.out - - @patch('mlx_knife.cache_utils.find_matching_models') - @patch('mlx_knife.cache_utils.MODEL_CACHE') - def test_resolve_single_model_no_hash_multiple_matches(self, mock_cache, mock_find, capsys): - """Test traditional ambiguous model behavior without hash is preserved.""" - # Mock find_matching_models to return multiple matches - mock_find.return_value = [ - (MagicMock(), "mlx-community/Llama-3.2-1B-Instruct-4bit"), - (MagicMock(), "mlx-community/Llama-3.3-70B-Instruct-4bit"), - ] - - # Mock MODEL_CACHE behavior for exact match check - mock_base_dir = MagicMock() - mock_base_dir.exists.return_value = False - mock_cache.__truediv__.return_value = mock_base_dir - - from mlx_knife.cache_utils import resolve_single_model - - result = resolve_single_model("Llama") # No hash specified - - # Should return None tuple - assert result == (None, None, None) - - # Check traditional error message was printed - captured = capsys.readouterr() - assert "Multiple models match 'Llama'. Please be more specific:" in captured.out - - -# Add pytest fixture at module level -@pytest.fixture -def temp_cache_dir(): - """Create temporary cache directory for testing.""" - with tempfile.TemporaryDirectory() as temp_dir: - yield Path(temp_dir) \ No newline at end of file diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py deleted file mode 100644 index 9082064..0000000 --- a/tests/unit/test_cli.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -Unit tests for cli.py module. - -Tests the command-line interface functionality: -- Argument parsing -- Command dispatch -- Help and version output -""" -import pytest -import argparse -from unittest.mock import patch, MagicMock -import sys -import os - -# Import the module under test -from mlx_knife.cli import main - - -class TestMainFunctionBasic: - """Test basic main function behavior without requiring parser creation.""" - - def test_main_function_exists(self): - """Test that main function exists and is callable.""" - try: - assert callable(main) - except Exception as e: - pytest.fail(f"Main function test failed: {e}") - - def test_version_flag_via_main(self): - """Test version flag through main function.""" - try: - with patch('sys.argv', ['mlxk', '--version']): - with pytest.raises(SystemExit) as exc_info: - main() - # Version should exit cleanly - assert exc_info.value.code in [0, None] - except Exception as e: - # It's OK if version parsing isn't fully implemented yet - pass - - -class TestMainFunction: - """Test main function behavior.""" - - def test_main_with_help(self): - """Test main function with help argument.""" - try: - with patch('sys.argv', ['mlxk', '--help']): - with pytest.raises(SystemExit) as exc_info: - main() - # Help should exit with code 0 - assert exc_info.value.code == 0 or exc_info.value.code is None - except Exception as e: - pytest.fail(f"Main function help test failed: {e}") - - def test_main_with_invalid_command(self): - """Test main function with invalid command.""" - try: - with patch('sys.argv', ['mlxk', 'invalid-command-xyz']): - with pytest.raises(SystemExit) as exc_info: - main() - # Invalid command should exit with non-zero code - assert exc_info.value.code != 0 - except Exception as e: - pytest.fail(f"Main function invalid command test failed: {e}") - - @patch('mlx_knife.cache_utils.list_models') - def test_main_with_list_command(self, mock_list_models): - """Test main function with list command.""" - try: - # Mock the list_models function to avoid actual cache interaction - mock_list_models.return_value = None - - with patch('sys.argv', ['mlxk', 'list']): - try: - main() - except SystemExit as e: - # List command might exit with 0 on success - assert e.code == 0 or e.code is None - except Exception as e: - pytest.fail(f"Main function list command test failed: {e}") - - @patch('mlx_knife.cache_utils.check_all_models_health') - def test_main_with_health_command(self, mock_health_check): - """Test main function with health command.""" - try: - # Mock the health check function - mock_health_check.return_value = None - - with patch('sys.argv', ['mlxk', 'health']): - try: - main() - except SystemExit as e: - # Health command should exit gracefully - assert e.code == 0 or e.code is None - except Exception as e: - pytest.fail(f"Main function health command test failed: {e}") - - def test_main_no_arguments(self): - """Test main function with no arguments.""" - try: - with patch('sys.argv', ['mlxk']): - # The CLI shows help when no args are provided - this is valid behavior - main() # Should complete successfully showing help - except SystemExit as e: - # Also valid - some CLIs exit after showing help - pass - except Exception as e: - pytest.fail(f"Main function no arguments test failed: {e}") - - -class TestErrorHandling: - """Test CLI error handling.""" - - def test_keyboard_interrupt_handling(self): - """Test handling of KeyboardInterrupt (Ctrl+C).""" - try: - # Test that KeyboardInterrupt doesn't crash the CLI completely - with patch('sys.argv', ['mlxk', 'list']): - with patch('builtins.print', side_effect=KeyboardInterrupt()): - try: - main() - except KeyboardInterrupt: - # KeyboardInterrupt propagating up is acceptable - pass - except SystemExit: - # Graceful exit is also acceptable - pass - except Exception as e: - pytest.fail(f"Keyboard interrupt handling test failed: {e}") - - def test_basic_command_robustness(self): - """Test that basic commands don't crash unexpectedly.""" - try: - # Test that list command runs successfully (already working based on earlier test) - with patch('sys.argv', ['mlxk', 'list']): - main() # Should work fine - except SystemExit: - # Exit is acceptable for some CLI implementations - pass - except Exception as e: - pytest.fail(f"Basic command robustness test failed: {e}") - - -class TestHealthCommandDefaultBehavior: - """Test health command default behavior (Issue 3).""" - - @patch('mlx_knife.cli.check_all_models_health') - def test_health_command_without_args_calls_all(self, mock_check_all): - """Test that 'mlxk health' (no args) calls check_all_models_health.""" - mock_check_all.return_value = True - - try: - with patch('sys.argv', ['mlxk', 'health']): - main() - - # Should have called check_all_models_health - assert mock_check_all.called - mock_check_all.assert_called_once() - except SystemExit: - # Exit is acceptable after running the command - assert mock_check_all.called - except Exception as e: - pytest.fail(f"Health command default behavior test failed: {e}") - - @patch('mlx_knife.cli.check_model_health') - @patch('mlx_knife.cli.check_all_models_health') - def test_health_command_with_specific_model(self, mock_check_all, mock_check_specific): - """Test that 'mlxk health model-name' calls check_model_health.""" - mock_check_specific.return_value = True - - try: - with patch('sys.argv', ['mlxk', 'health', 'some-model']): - main() - - # Should have called check_model_health with the specific model - assert mock_check_specific.called - mock_check_specific.assert_called_once_with('some-model') - - # Should NOT have called check_all_models_health - assert not mock_check_all.called - except SystemExit: - # Exit is acceptable after running the command - assert mock_check_specific.called - assert not mock_check_all.called - except Exception as e: - pytest.fail(f"Health command specific model test failed: {e}") - - @patch('mlx_knife.cli.check_all_models_health') - def test_health_command_backward_compatibility_with_all_flag(self, mock_check_all): - """Test that 'mlxk health --all' still works for backward compatibility.""" - mock_check_all.return_value = True - - try: - with patch('sys.argv', ['mlxk', 'health', '--all']): - main() - - # Should have called check_all_models_health - assert mock_check_all.called - mock_check_all.assert_called_once() - except SystemExit: - # Exit is acceptable after running the command - assert mock_check_all.called - except Exception as e: - pytest.fail(f"Health command --all flag test failed: {e}") \ No newline at end of file diff --git a/tests/unit/test_mlx_runner_memory.py b/tests/unit/test_mlx_runner_memory.py deleted file mode 100644 index 3dcdc88..0000000 --- a/tests/unit/test_mlx_runner_memory.py +++ /dev/null @@ -1,551 +0,0 @@ -""" -Unit tests for MLXRunner memory management robustness and context length handling. - -Tests context manager implementation, exception handling, cleanup guarantees, -and model context length extraction without requiring actual MLX models. -""" -import json -import os -import tempfile -import unittest -from unittest.mock import MagicMock, patch, PropertyMock -import gc - - -class TestMLXRunnerMemoryManagement(unittest.TestCase): - """Test MLXRunner memory management robustness.""" - - @patch('mlx_knife.mlx_runner.mx') - @patch('mlx_knife.mlx_runner.load') - def test_context_manager_basic_flow(self, mock_load, mock_mx): - """Test basic context manager flow with successful execution.""" - from mlx_knife.mlx_runner import MLXRunner - - # Setup mocks - mock_model = MagicMock() - mock_tokenizer = MagicMock() - mock_tokenizer.eos_token = '' - mock_tokenizer.eos_token_id = 2 - mock_load.return_value = (mock_model, mock_tokenizer) - mock_mx.get_active_memory.return_value = 1024 * 1024 * 1024 # 1GB - - # Test successful context manager usage - with MLXRunner("test_model", verbose=False) as runner: - self.assertIsNotNone(runner.model) - self.assertIsNotNone(runner.tokenizer) - self.assertTrue(runner._model_loaded) - self.assertTrue(runner._context_entered) - - # After exiting context, model should be cleaned up - self.assertIsNone(runner.model) - self.assertIsNone(runner.tokenizer) - self.assertFalse(runner._model_loaded) - self.assertFalse(runner._context_entered) - - # Verify cleanup was called - mock_mx.clear_cache.assert_called() - - @patch('mlx_knife.mlx_runner.mx') - @patch('mlx_knife.mlx_runner.load') - def test_context_manager_exception_in_load(self, mock_load, mock_mx): - """Test cleanup when exception occurs during model loading.""" - from mlx_knife.mlx_runner import MLXRunner - - # Setup mock to fail during load - mock_load.side_effect = RuntimeError("Model loading failed") - mock_mx.get_active_memory.return_value = 1024 * 1024 * 1024 - - # Test that exception is propagated and cleanup happens - with self.assertRaises(RuntimeError) as cm: - with MLXRunner("test_model", verbose=False) as runner: - pass # Should never reach here - - self.assertIn("Failed to load model", str(cm.exception)) - - # Verify cleanup was called even on failure - mock_mx.clear_cache.assert_called() - - @patch('mlx_knife.mlx_runner.mx') - @patch('mlx_knife.mlx_runner.load') - def test_context_manager_exception_in_body(self, mock_load, mock_mx): - """Test cleanup when exception occurs in context body.""" - from mlx_knife.mlx_runner import MLXRunner - - # Setup successful mocks - mock_model = MagicMock() - mock_tokenizer = MagicMock() - mock_tokenizer.eos_token = '' - mock_tokenizer.eos_token_id = 2 - mock_load.return_value = (mock_model, mock_tokenizer) - mock_mx.get_active_memory.return_value = 1024 * 1024 * 1024 - - # Test exception in context body - with self.assertRaises(ValueError): - with MLXRunner("test_model", verbose=False) as runner: - self.assertTrue(runner._model_loaded) - raise ValueError("User error") - - # Cleanup should still happen - self.assertIsNone(runner.model) - self.assertIsNone(runner.tokenizer) - self.assertFalse(runner._model_loaded) - mock_mx.clear_cache.assert_called() - - @patch('mlx_knife.mlx_runner.mx') - @patch('mlx_knife.mlx_runner.load') - def test_prevent_nested_context_usage(self, mock_load, mock_mx): - """Test that nested context manager usage is prevented.""" - from mlx_knife.mlx_runner import MLXRunner - - # Setup mocks - mock_model = MagicMock() - mock_tokenizer = MagicMock() - mock_tokenizer.eos_token = '' - mock_tokenizer.eos_token_id = 2 - mock_load.return_value = (mock_model, mock_tokenizer) - mock_mx.get_active_memory.return_value = 1024 * 1024 * 1024 - - runner = MLXRunner("test_model", verbose=False) - - # First context should work - with runner: - self.assertTrue(runner._context_entered) - - # Nested context should fail - with self.assertRaises(RuntimeError) as cm: - with runner: - pass - - self.assertIn("cannot be entered multiple times", str(cm.exception)) - - # After exiting, should be able to use again - self.assertFalse(runner._context_entered) - - # Second usage should work - with runner: - self.assertTrue(runner._context_entered) - - @patch('mlx_knife.mlx_runner.mx') - @patch('mlx_knife.mlx_runner.load') - def test_partial_loading_failure_cleanup(self, mock_load, mock_mx): - """Test cleanup when loading partially succeeds then fails.""" - from mlx_knife.mlx_runner import MLXRunner - - # Setup mock to partially succeed - mock_model = MagicMock() - mock_tokenizer = MagicMock() - - # Missing required attributes to trigger failure in _extract_stop_tokens - del mock_tokenizer.eos_token - del mock_tokenizer.eos_token_id - mock_tokenizer.encode.side_effect = Exception("Tokenizer error") - - mock_load.return_value = (mock_model, mock_tokenizer) - mock_mx.get_active_memory.return_value = 1024 * 1024 * 1024 - - runner = MLXRunner("test_model", verbose=False) - - # Load should succeed even with tokenizer issues - try: - runner.load_model() - # Model should be loaded even if stop token extraction had issues - self.assertIsNotNone(runner.model) - self.assertIsNotNone(runner.tokenizer) - finally: - # Cleanup should work regardless - runner.cleanup() - self.assertIsNone(runner.model) - self.assertIsNone(runner.tokenizer) - mock_mx.clear_cache.assert_called() - - @patch('mlx_knife.mlx_runner.mx') - def test_cleanup_idempotency(self, mock_mx): - """Test that cleanup can be called multiple times safely.""" - from mlx_knife.mlx_runner import MLXRunner - - mock_mx.get_active_memory.return_value = 1024 * 1024 * 1024 - - runner = MLXRunner("test_model", verbose=False) - runner.model = MagicMock() - runner.tokenizer = MagicMock() - runner._model_loaded = True - - # Call cleanup multiple times - for _ in range(3): - runner.cleanup() - self.assertIsNone(runner.model) - self.assertIsNone(runner.tokenizer) - self.assertFalse(runner._model_loaded) - - # Should have been called at least once - mock_mx.clear_cache.assert_called() - - @patch('mlx_knife.mlx_runner.mx') - @patch('mlx_knife.mlx_runner.load') - def test_memory_baseline_tracking(self, mock_load, mock_mx): - """Test memory baseline is properly tracked.""" - from mlx_knife.mlx_runner import MLXRunner - - # Setup mocks - mock_model = MagicMock() - mock_tokenizer = MagicMock() - mock_tokenizer.eos_token = '' - mock_tokenizer.eos_token_id = 2 - mock_load.return_value = (mock_model, mock_tokenizer) - - # Simulate memory growth during loading - memory_values = [ - 1 * 1024**3, # 1GB baseline - 5 * 1024**3, # 5GB after loading - 5 * 1024**3, # 5GB when querying stats - ] - mock_mx.get_active_memory.side_effect = memory_values - - runner = MLXRunner("test_model", verbose=False) - runner.load_model() - - # Check baseline was captured - self.assertEqual(runner._memory_baseline, 1.0) # 1GB - - # Check memory usage calculation - memory_stats = runner.get_memory_usage() - self.assertEqual(memory_stats["model_gb"], 4.0) # 5GB - 1GB = 4GB - - @patch('mlx_knife.mlx_runner.mx') - @patch('mlx_knife.mlx_runner.load') - def test_generate_without_loading(self, mock_load, mock_mx): - """Test that generate methods fail gracefully without loaded model.""" - from mlx_knife.mlx_runner import MLXRunner - - runner = MLXRunner("test_model", verbose=False) - - # Try to generate without loading - with self.assertRaises(RuntimeError) as cm: - list(runner.generate_streaming("test prompt")) - self.assertIn("Model not loaded", str(cm.exception)) - - with self.assertRaises(RuntimeError) as cm: - runner.generate_batch("test prompt") - self.assertIn("Model not loaded", str(cm.exception)) - - @patch('mlx_knife.mlx_runner.mx') - @patch('mlx_knife.mlx_runner.load') - def test_server_usage_without_context_manager(self, mock_load, mock_mx): - """Test server-style usage without context manager.""" - from mlx_knife.mlx_runner import MLXRunner - - # Setup mocks - mock_model = MagicMock() - mock_tokenizer = MagicMock() - mock_tokenizer.eos_token = '' - mock_tokenizer.eos_token_id = 2 - mock_load.return_value = (mock_model, mock_tokenizer) - mock_mx.get_active_memory.return_value = 1024 * 1024 * 1024 - - # Server style: manual load and cleanup - runner = MLXRunner("test_model", verbose=False) - - try: - runner.load_model() - self.assertTrue(runner._model_loaded) - self.assertIsNotNone(runner.model) - - # Simulate server keeping model loaded - # and potentially switching models - runner.cleanup() - self.assertFalse(runner._model_loaded) - self.assertIsNone(runner.model) - - # Load again (simulating model switch) - runner.load_model() - self.assertTrue(runner._model_loaded) - - finally: - # Ensure cleanup happens - runner.cleanup() - self.assertFalse(runner._model_loaded) - - @patch('mlx_knife.mlx_runner.mx') - @patch('mlx_knife.mlx_runner.load') - def test_exception_during_cleanup(self, mock_load, mock_mx): - """Test that cleanup handles exceptions gracefully.""" - from mlx_knife.mlx_runner import MLXRunner - - # Setup mocks - mock_model = MagicMock() - mock_tokenizer = MagicMock() - mock_tokenizer.eos_token = '' - mock_tokenizer.eos_token_id = 2 - mock_load.return_value = (mock_model, mock_tokenizer) - mock_mx.get_active_memory.return_value = 1024 * 1024 * 1024 - - # Make clear_cache raise an exception - mock_mx.clear_cache.side_effect = Exception("Cache clear failed") - - runner = MLXRunner("test_model", verbose=False) - runner.load_model() - - # Cleanup should complete even if mx.clear_cache fails - runner.cleanup() # Should not raise - - # State should still be cleaned - self.assertIsNone(runner.model) - self.assertIsNone(runner.tokenizer) - self.assertFalse(runner._model_loaded) - - -class TestModelContextLength(unittest.TestCase): - """Test model context length extraction functionality.""" - - def test_get_model_context_length_with_max_position_embeddings(self): - """Test context length extraction from max_position_embeddings.""" - from mlx_knife.mlx_runner import get_model_context_length - - with tempfile.TemporaryDirectory() as temp_dir: - config_path = os.path.join(temp_dir, "config.json") - config = { - "max_position_embeddings": 4096, - "hidden_size": 768, - "num_attention_heads": 12 - } - - with open(config_path, 'w') as f: - json.dump(config, f) - - context_length = get_model_context_length(temp_dir) - self.assertEqual(context_length, 4096) - - def test_get_model_context_length_with_n_positions(self): - """Test context length extraction from n_positions (GPT-style).""" - from mlx_knife.mlx_runner import get_model_context_length - - with tempfile.TemporaryDirectory() as temp_dir: - config_path = os.path.join(temp_dir, "config.json") - config = { - "n_positions": 2048, - "n_embd": 512, - "n_head": 8 - } - - with open(config_path, 'w') as f: - json.dump(config, f) - - context_length = get_model_context_length(temp_dir) - self.assertEqual(context_length, 2048) - - def test_get_model_context_length_with_context_length(self): - """Test context length extraction from context_length field.""" - from mlx_knife.mlx_runner import get_model_context_length - - with tempfile.TemporaryDirectory() as temp_dir: - config_path = os.path.join(temp_dir, "config.json") - config = { - "context_length": 8192, - "hidden_size": 1024 - } - - with open(config_path, 'w') as f: - json.dump(config, f) - - context_length = get_model_context_length(temp_dir) - self.assertEqual(context_length, 8192) - - def test_get_model_context_length_with_max_sequence_length(self): - """Test context length extraction from max_sequence_length.""" - from mlx_knife.mlx_runner import get_model_context_length - - with tempfile.TemporaryDirectory() as temp_dir: - config_path = os.path.join(temp_dir, "config.json") - config = { - "max_sequence_length": 32768, - "d_model": 2048 - } - - with open(config_path, 'w') as f: - json.dump(config, f) - - context_length = get_model_context_length(temp_dir) - self.assertEqual(context_length, 32768) - - def test_get_model_context_length_with_seq_len(self): - """Test context length extraction from seq_len field.""" - from mlx_knife.mlx_runner import get_model_context_length - - with tempfile.TemporaryDirectory() as temp_dir: - config_path = os.path.join(temp_dir, "config.json") - config = { - "seq_len": 16384, - "embedding_size": 1536 - } - - with open(config_path, 'w') as f: - json.dump(config, f) - - context_length = get_model_context_length(temp_dir) - self.assertEqual(context_length, 16384) - - def test_get_model_context_length_priority_order(self): - """Test that max_position_embeddings takes priority over other fields.""" - from mlx_knife.mlx_runner import get_model_context_length - - with tempfile.TemporaryDirectory() as temp_dir: - config_path = os.path.join(temp_dir, "config.json") - config = { - "max_position_embeddings": 4096, # Should be used (first in priority) - "n_positions": 2048, - "context_length": 8192, - "max_sequence_length": 16384, - "seq_len": 1024 - } - - with open(config_path, 'w') as f: - json.dump(config, f) - - context_length = get_model_context_length(temp_dir) - self.assertEqual(context_length, 4096) - - def test_get_model_context_length_missing_config_file(self): - """Test default context length when config.json is missing.""" - from mlx_knife.mlx_runner import get_model_context_length - - with tempfile.TemporaryDirectory() as temp_dir: - # No config.json file created - context_length = get_model_context_length(temp_dir) - self.assertEqual(context_length, 4096) # Default fallback - - def test_get_model_context_length_invalid_json(self): - """Test default context length when config.json is malformed.""" - from mlx_knife.mlx_runner import get_model_context_length - - with tempfile.TemporaryDirectory() as temp_dir: - config_path = os.path.join(temp_dir, "config.json") - - # Write invalid JSON - with open(config_path, 'w') as f: - f.write("{ invalid json content") - - context_length = get_model_context_length(temp_dir) - self.assertEqual(context_length, 4096) # Default fallback - - def test_get_model_context_length_empty_config(self): - """Test default context length when config.json has no context fields.""" - from mlx_knife.mlx_runner import get_model_context_length - - with tempfile.TemporaryDirectory() as temp_dir: - config_path = os.path.join(temp_dir, "config.json") - config = { - "hidden_size": 768, - "num_attention_heads": 12, - "model_type": "test_model" - } - - with open(config_path, 'w') as f: - json.dump(config, f) - - context_length = get_model_context_length(temp_dir) - self.assertEqual(context_length, 4096) # Default fallback - - -class TestMLXRunnerContextAwareLimits(unittest.TestCase): - """Test MLXRunner context-aware token limits.""" - - @patch('mlx_knife.mlx_runner.get_model_context_length') - def test_get_effective_max_tokens_interactive_mode(self, mock_get_context): - """Test effective max tokens in interactive mode (uses full context).""" - from mlx_knife.mlx_runner import MLXRunner - - mock_get_context.return_value = 4096 - - runner = MLXRunner("test_model", verbose=False) - runner._context_length = 4096 - - # Interactive mode: should use full context length - effective = runner.get_effective_max_tokens(8000, interactive=True) - self.assertEqual(effective, 4096) # Limited by model context - - effective = runner.get_effective_max_tokens(2000, interactive=True) - self.assertEqual(effective, 2000) # User request is smaller - - @patch('mlx_knife.mlx_runner.get_model_context_length') - def test_get_effective_max_tokens_server_mode(self, mock_get_context): - """Test effective max tokens in server mode (uses half context for DoS protection).""" - from mlx_knife.mlx_runner import MLXRunner - - mock_get_context.return_value = 4096 - - runner = MLXRunner("test_model", verbose=False) - runner._context_length = 4096 - - # Server mode: should use half context length - effective = runner.get_effective_max_tokens(8000, interactive=False) - self.assertEqual(effective, 2048) # Limited by server limit (4096 / 2) - - effective = runner.get_effective_max_tokens(1000, interactive=False) - self.assertEqual(effective, 1000) # User request is smaller - - @patch('mlx_knife.mlx_runner.get_model_context_length') - def test_get_effective_max_tokens_no_context_length(self, mock_get_context): - """Test effective max tokens when context length is unknown.""" - from mlx_knife.mlx_runner import MLXRunner - - runner = MLXRunner("test_model", verbose=False) - runner._context_length = None # Context length unknown - - # Should fallback to requested tokens - effective = runner.get_effective_max_tokens(1500, interactive=True) - self.assertEqual(effective, 1500) - - effective = runner.get_effective_max_tokens(2500, interactive=False) - self.assertEqual(effective, 2500) - - @patch('mlx_knife.mlx_runner.get_model_context_length') - def test_get_effective_max_tokens_none_interactive_mode(self, mock_get_context): - """Test that None (no --max-tokens) uses full context in interactive mode.""" - from mlx_knife.mlx_runner import MLXRunner - - mock_get_context.return_value = 4096 - - runner = MLXRunner("test_model", verbose=False) - runner._context_length = 4096 - - # None (user didn't specify --max-tokens) should use full context - effective = runner.get_effective_max_tokens(None, interactive=True) - self.assertEqual(effective, 4096) - - # Explicit values should still be respected - effective = runner.get_effective_max_tokens(500, interactive=True) - self.assertEqual(effective, 500) # Now 500 is treated as explicit user choice - - @patch('mlx_knife.mlx_runner.get_model_context_length') - def test_get_effective_max_tokens_none_server_mode(self, mock_get_context): - """Test that None uses server default in server mode.""" - from mlx_knife.mlx_runner import MLXRunner - - mock_get_context.return_value = 4096 - - runner = MLXRunner("test_model", verbose=False) - runner._context_length = 4096 - - # None in server mode should use server limit (context / 2) - effective = runner.get_effective_max_tokens(None, interactive=False) - self.assertEqual(effective, 2048) # 4096 / 2 - - @patch('mlx_knife.mlx_runner.get_model_context_length') - def test_get_effective_max_tokens_none_unknown_context(self, mock_get_context): - """Test None behavior when context length is unknown.""" - from mlx_knife.mlx_runner import MLXRunner - - runner = MLXRunner("test_model", verbose=False) - runner._context_length = None - - # Interactive mode: should use 4096 fallback when None - effective = runner.get_effective_max_tokens(None, interactive=True) - self.assertEqual(effective, 4096) - - # Server mode: should use 2048 fallback when None - effective = runner.get_effective_max_tokens(None, interactive=False) - self.assertEqual(effective, 2048) - - -if __name__ == '__main__': - unittest.main() \ No newline at end of file diff --git a/tests_2.0/live/test_list_human_live.py b/tests_2.0/live/test_list_human_live.py new file mode 100644 index 0000000..ae97b01 --- /dev/null +++ b/tests_2.0/live/test_list_human_live.py @@ -0,0 +1,87 @@ +"""Opt-in live E2E test for human list rendering using the real HF cache. + +This test is skipped by default. Enable by setting: +- MLXK2_LIVE_LIST=1 +- HF_HOME must point to your Hugging Face cache (read-only) + +It validates that: +- Default list shows only MLX chat models (hides MLX base) +- list --verbose shows all MLX (chat + base) +- list --all shows all frameworks +""" + +from __future__ import annotations + +import json +import os +import sys +from typing import List, Dict + +import pytest + +pytestmark = [pytest.mark.wet, pytest.mark.live_list] + + +def _run_cli(argv: List[str], capsys) -> str: + from mlxk2.cli import main as cli_main + old_argv = sys.argv[:] + sys.argv = argv[:] + try: + with pytest.raises(SystemExit): + cli_main() + finally: + sys.argv = old_argv + out = capsys.readouterr().out + return out + + +def _json_models(capsys) -> List[Dict]: + out = _run_cli(["mlxk2", "list", "--json"], capsys) + data = json.loads(out) + assert data["status"] == "success" and data["command"] == "list" + return data["data"]["models"] + + +def _display_name_for_default(name: str) -> str: + # In compact default view, we strip mlx-community/ prefix + return name.split("/", 1)[1] if name.startswith("mlx-community/") else name + + +def test_live_list_human_variants(capsys, request): + # Only run when explicitly selected with -m live_list + selected = request.config.getoption("-m") or "" + if "live_list" not in selected: + pytest.skip("Run with -m live_list to enable this end-to-end test") + models = _json_models(capsys) + + mlx = [m for m in models if m.get("framework") == "MLX"] + mlx_chat = [m for m in mlx if m.get("model_type") == "chat"] + mlx_base = [m for m in mlx if m.get("model_type") == "base"] + other = [m for m in models if m.get("framework") != "MLX"] + + # Fail if the cache doesn't have the necessary models + assert mlx_chat, "Need at least one MLX chat model in HF cache" + assert mlx_base, "Need at least one MLX base model in HF cache" + + chat_name = mlx_chat[0]["name"] + base_name = mlx_base[0]["name"] + + # Default list: only MLX chat + out_default = _run_cli(["mlxk2", "list"], capsys) + assert _display_name_for_default(chat_name) in out_default + assert _display_name_for_default(base_name) not in out_default + + # Verbose: all MLX (chat + base) + out_verbose = _run_cli(["mlxk2", "list", "--verbose"], capsys) + assert chat_name in out_verbose + assert base_name in out_verbose + + # All: all frameworks + out_all = _run_cli(["mlxk2", "list", "--all"], capsys) + assert _display_name_for_default(chat_name) in out_all or chat_name in out_all + assert _display_name_for_default(base_name) in out_all or base_name in out_all + + if other: + other_name = other[0]["name"] + # Non-MLX names are never stripped by default rule + assert other_name in out_all diff --git a/tests_2.0/test_detection_readme_tokenizer.py b/tests_2.0/test_detection_readme_tokenizer.py new file mode 100644 index 0000000..81ea82a --- /dev/null +++ b/tests_2.0/test_detection_readme_tokenizer.py @@ -0,0 +1,87 @@ +"""Tests for lenient MLX detection (Issue #31 port) in 2.0. + +Covers: +- Framework=MLX via README front-matter (tags/library_name) for non-mlx-community repos. +- Type=chat via tokenizer chat_template hints. +- Consistency between list and show outputs. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Tuple + +from mlxk2.core.cache import hf_to_cache_dir +from mlxk2.operations.list import list_models +from mlxk2.operations.show import show_model_operation + + +def _mk_snapshot(cache_hub: Path, repo_id: str, hash40: str) -> Tuple[Path, Path]: + base = cache_hub / hf_to_cache_dir(repo_id) + snap = base / "snapshots" / hash40 + snap.mkdir(parents=True, exist_ok=True) + # Minimal healthy files + (snap / "config.json").write_text('{"model_type": "test"}', encoding="utf-8") + (snap / "model.safetensors").write_bytes(b"w" * 1024) + return base, snap + + +def test_framework_mlx_from_front_matter(isolated_cache): + repo = "custom-org/FrontMatter-Model" + h = "0123456789abcdef0123456789abcdef01234567" + base, snap = _mk_snapshot(isolated_cache, repo, h) + + # README front-matter indicating MLX + (snap / "README.md").write_text( + """--- +library_name: mlx +tags: [mlx, chat] +--- + +# Dummy +""", + encoding="utf-8", + ) + + out = list_models() + models = {m["name"]: m for m in out["data"]["models"]} + assert repo in models, f"Model not listed: {repo}" + assert models[repo]["framework"] == "MLX" + + s = show_model_operation(repo) + assert s["status"] == "success" + assert s["data"]["model"]["framework"] == "MLX" + + +def test_type_chat_from_tokenizer_chat_template(isolated_cache): + repo = "custom-org/Tokenizer-Chat-Model" + h = "89abcdef0123456789abcdef0123456789abcdef" + base, snap = _mk_snapshot(isolated_cache, repo, h) + + # No chat/instruct in name → rely on tokenizer chat_template + (snap / "tokenizer_config.json").write_text( + '{"chat_template": "{{ bos_token }}{{ eos_token }}"}', encoding="utf-8" + ) + + # Also put a front-matter not mentioning mlx to ensure chat comes from tokenizer + (snap / "README.md").write_text( + """--- +tags: [test] +--- +""", + encoding="utf-8", + ) + + out = list_models() + models = {m["name"]: m for m in out["data"]["models"]} + assert repo in models, f"Model not listed: {repo}" + m = models[repo] + assert m["model_type"] == "chat" + assert "chat" in (m.get("capabilities") or []) + + s = show_model_operation(repo) + assert s["status"] == "success" + ms = s["data"]["model"] + assert ms["model_type"] == "chat" + assert "chat" in (ms.get("capabilities") or []) + diff --git a/tests_2.0/test_human_output.py b/tests_2.0/test_human_output.py index 92c5f4b..d597fb5 100644 --- a/tests_2.0/test_human_output.py +++ b/tests_2.0/test_human_output.py @@ -80,3 +80,97 @@ def test_health_human_summary_and_entries(): assert "model-a" in out assert "model-b" in out + +def test_list_human_filters_mlx_base_default(): + from mlxk2.output.human import render_list + + data = { + "status": "success", + "command": "list", + "data": { + "models": [ + { + "name": "org/MLXChat", + "hash": "abcdef0123456789abcdef0123456789abcdef01", + "size_bytes": 1000, + "last_modified": "2025-08-30T12:00:00Z", + "framework": "MLX", + "model_type": "chat", + "capabilities": ["text-generation", "chat"], + "health": "healthy", + "cached": True, + }, + { + "name": "org/MLXBase", + "hash": "abcdef0123456789abcdef0123456789abcdef02", + "size_bytes": 2000, + "last_modified": "2025-08-30T12:00:00Z", + "framework": "MLX", + "model_type": "base", + "capabilities": ["text-generation"], + "health": "healthy", + "cached": True, + }, + ], + "count": 2, + }, + "error": None, + } + + # Default (compact) should hide MLX base + out_default = render_list(data, show_health=False, show_all=False, verbose=False) + assert "MLXChat" in out_default + assert "MLXBase" not in out_default + + # Verbose (without --all) shows all MLX (chat + base) + out_verbose = render_list(data, show_health=False, show_all=False, verbose=True) + assert "MLXChat" in out_verbose + assert "MLXBase" in out_verbose + + +def test_list_human_verbose_shows_all_mlx_only(): + from mlxk2.output.human import render_list + + data = { + "status": "success", + "command": "list", + "data": { + "models": [ + {"name": "org/MLXChat", "hash": None, "size_bytes": 1, "last_modified": "2025-08-30T12:00:00Z", "framework": "MLX", "model_type": "chat", "capabilities": ["text-generation", "chat"], "health": "healthy", "cached": True}, + {"name": "org/MLXBase", "hash": None, "size_bytes": 1, "last_modified": "2025-08-30T12:00:00Z", "framework": "MLX", "model_type": "base", "capabilities": ["text-generation"], "health": "healthy", "cached": True}, + {"name": "org/OtherPT", "hash": None, "size_bytes": 1, "last_modified": "2025-08-30T12:00:00Z", "framework": "PyTorch", "model_type": "base", "capabilities": ["text-generation"], "health": "healthy", "cached": True}, + ], + "count": 3, + }, + "error": None, + } + + out_verbose = render_list(data, show_health=False, show_all=False, verbose=True) + # Shows both MLX models (chat+base) + assert "MLXChat" in out_verbose + assert "MLXBase" in out_verbose + # Hides non-MLX + assert "OtherPT" not in out_verbose + + +def test_list_human_all_shows_all_frameworks(): + from mlxk2.output.human import render_list + + data = { + "status": "success", + "command": "list", + "data": { + "models": [ + {"name": "org/MLXChat", "hash": None, "size_bytes": 1, "last_modified": "2025-08-30T12:00:00Z", "framework": "MLX", "model_type": "chat", "capabilities": ["text-generation", "chat"], "health": "healthy", "cached": True}, + {"name": "org/OtherGGUF", "hash": None, "size_bytes": 1, "last_modified": "2025-08-30T12:00:00Z", "framework": "GGUF", "model_type": "base", "capabilities": ["text-generation"], "health": "unhealthy", "cached": True}, + {"name": "org/OtherPT", "hash": None, "size_bytes": 1, "last_modified": "2025-08-30T12:00:00Z", "framework": "PyTorch", "model_type": "base", "capabilities": ["text-generation"], "health": "healthy", "cached": True}, + ], + "count": 3, + }, + "error": None, + } + + out_all = render_list(data, show_health=False, show_all=True, verbose=False) + assert "MLXChat" in out_all + assert "OtherGGUF" in out_all + assert "OtherPT" in out_all diff --git a/tests_2.0/test_push_extended.py b/tests_2.0/test_push_extended.py index a2672ed..34fbeaf 100644 --- a/tests_2.0/test_push_extended.py +++ b/tests_2.0/test_push_extended.py @@ -213,3 +213,47 @@ def test_push_hfignore_is_merged_with_defaults(tmp_path, monkeypatch): # Ensure .hfignore additions are present assert ".idea/" in pats and ".vscode/" in pats and "*.ipynb" in pats + +def test_push_retry_creates_branch_on_upload_revision_error(tmp_path, monkeypatch): + """If upload fails with a revision-not-found style error and --create is set, + the operation should create the branch and retry once, succeeding offline.""" + monkeypatch.setenv("HF_TOKEN", "dummy") + ws = tmp_path / "ws" + ws.mkdir() + (ws / "file.txt").write_text("x") + + class _ApiOk(_FakeHfApi): + instance = None # type: ignore[var-annotated] + + def __init__(self, token: str | None = None) -> None: # type: ignore[override] + super().__init__(token) + self.created_branches: list[tuple[str, str]] = [] + _ApiOk.instance = self + + def create_branch(self, repo_id: str, repo_type: str, branch: str): # type: ignore[override] + self.created_branches.append((repo_id, branch)) + return {"ok": True} + + state = {"attempt": 0} + + def upload_folder(**kwargs): # type: ignore + # First attempt fails with a hub-like error; second succeeds + if state["attempt"] == 0: + state["attempt"] += 1 + raise _Errors.HfHubHTTPError("Invalid rev id: test-branch") + state["attempt"] += 1 + return SimpleNamespace(commit_id="0123456789abcdef0123456789abcdef01234567") + + fake = SimpleNamespace(HfApi=_ApiOk, upload_folder=upload_folder, errors=_Errors) + sys.modules["huggingface_hub"] = fake # type: ignore + sys.modules["huggingface_hub.errors"] = _Errors # type: ignore + monkeypatch.setitem(sys.modules, "huggingface_hub", fake) + monkeypatch.setitem(sys.modules, "huggingface_hub.errors", _Errors) + + res = push_operation(str(ws), "user/repo", create=True, private=True, branch="test-branch") + assert res["status"] == "success" + # Ensure we retried exactly once (two attempts total) + assert state["attempt"] == 2 + # Ensure branch creation was attempted once + assert _ApiOk.instance is not None + assert ("user/repo", "test-branch") in (_ApiOk.instance.created_branches if _ApiOk.instance else [])