MLX Swift LM

by Osaurus | Fork of ml-explore/mlx-swift-lm

A Swift package for building applications with large language models (LLMs) and vision language models (VLMs) on Apple Silicon, powered by MLX Swift.

This fork adds native JANG mixed-precision quantization, TurboQuant KV cache compression (4.7-5.0x memory savings), Gemma 4, Mistral Small 4, speculative decoding, VLM detection, and MoE performance optimizations on top of the full upstream library. Existing apps don't need to change anything -- all upstream APIs are preserved.

What's New in This Fork

New Model Architectures

Gemma 4 -- Google's latest, with both MoE and dense variants:

| Variant | Params | Architecture | VLM |
|---|---|---|---|
| 26B (A4B) | 26B total, 4B active | MoE (128 experts, top-8) | Yes |
| 31B | 31B dense | Mixed sliding/full attention | Yes |

Mistral Small 4 -- 119B MoE with Multi-head Latent Attention:

| Variant | Params | Architecture | VLM |
|---|---|---|---|
| 119B (A8B) | 119B total, 8B active | MLA + 128 experts + shared expert | Yes (Pixtral) |

JANG Mixed-Precision Quantization

JANG models use per-layer mixed-precision -- attention at 6-8 bit, MLP/experts at 2-4 bit -- for better quality at the same memory. Loaded natively with zero code changes:

```swift
// Loading a JANG model is identical to any other model
let container = try await loadModelContainer(
    from: URL(filePath: "/path/to/Gemma-4-26B-A4B-it-JANG_4M"),
    using: TokenizersLoader()
)
```

Performance

MoE models run up to 4x faster than upstream thanks to computation graph optimization, bfloat16 dtype, compiled activations, and Metal memory management:

| Model | Upstream | This Fork | Gain |
|---|---|---|---|
| Gemma 4 26B MoE | 25.0 tok/s | 101 tok/s | +304% |
| Qwen 3.5-35B MoE | 42.4 tok/s | 61 tok/s | +44% |
| NemotronH 30B-A3B | ~25 tok/s | 48 tok/s | +92% |
| Qwen 3.5-4B Dense | 123 tok/s | 145 tok/s | +18% |

Key optimizations:

  • Computation graph cleanup: Eliminated 86 unnecessary .asType() MLX graph nodes per Gemma4 MoE forward pass. Each redundant type cast was a separate Metal kernel dispatch.
  • Compiled logit softcap: Fused divide + tanh + multiply into a single Metal dispatch via compile(shapeless: true), matching Python's @mx.compile decorator.
  • Periodic Metal cache cleanup: Memory.clearCache() every 256 tokens reduces GPU allocator fragmentation from MoE expert weight cycling.
  • GPU memory pinning: mlx_set_wired_limit via Cmlx prevents macOS from paging model weights to SSD between tokens.
  • bfloat16 MoE conversion: Prevents Metal's automatic float16 to float32 promotion on mixed-dtype MoE operations.
  • Compiled GLU activations: Fused SwiGLU/GeGLU into single Metal dispatches.
  • Symlink resolution: Properly follows symlinked model directories (mlxstudio compatibility).
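To illustrate the math behind the compiled logit softcap above, here is a plain-Swift scalar reference (not the fused Metal path; `softcap` is a hypothetical helper name used only for this sketch):

```swift
import Foundation

/// Reference implementation of logit soft-capping: smoothly bounds
/// logits to (-cap, +cap) via tanh. The fork fuses the divide + tanh
/// + multiply into a single Metal dispatch; this scalar version just
/// shows the math being fused.
func softcap(_ x: Double, cap: Double) -> Double {
    cap * tanh(x / cap)
}

let cap = 30.0
print(softcap(0, cap: cap))     // near zero the function is ~identity
print(softcap(1000, cap: cap))  // large logits saturate at the cap
```

Because the three element-wise operations are shape-independent, compiling them with `compile(shapeless: true)` lets one fused kernel serve every batch and sequence length.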

Speculative Decoding

Use a smaller draft model to speed up generation by 29-79% (cherry-picked from upstream ml-explore#173):

```swift
let mainModel = try await loadModelContainer(
    from: HubClient.default, using: TokenizersLoader(),
    id: "mlx-community/Qwen3-14B-4bit")
let draftModel = try await loadModelContainer(
    from: HubClient.default, using: TokenizersLoader(),
    id: "mlx-community/Qwen3-0.6B-4bit")

let result = try await mainModel.generate(
    input: input, parameters: params, draft: draftModel)
```

VLM Detection

Check at runtime whether a model supports vision input:

```swift
if await container.isVLM {
    // safe to pass images
}
```

Works from MLXLMCommon alone -- no need to import MLXVLM.

TurboQuant KV Cache Compression

Compress the KV cache 4.7-5.0x during inference with no quality loss on short outputs and minimal divergence on long outputs. Based on Google DeepMind's research (arXiv:2504.19874).

One line to enable, works with every model -- no model changes needed:

```swift
// 3-bit (recommended default -- best compression)
let params = GenerateParameters(
    kvMode: .turboQuant(keyBits: 3, valueBits: 3))

// 4-bit (higher quality, less compression)
let params = GenerateParameters(
    kvMode: .turboQuant(keyBits: 4, valueBits: 4))
```

Works with ChatSession for multi-turn conversations:

```swift
let session = ChatSession(
    container,
    generateParameters: GenerateParameters(
        kvMode: .turboQuant(keyBits: 3, valueBits: 3)))

let reply1 = try await session.respond(to: "What is the capital of Japan?")
// "Tokyo"
let reply2 = try await session.respond(to: "What country is that city in?")
// "Japan" -- context preserved across turns
```

Works with speculative decoding:

```swift
let params = GenerateParameters(
    kvMode: .turboQuant(keyBits: 3, valueBits: 3))
let result = try await mainModel.generate(
    input: input, parameters: params, draft: draftModel)
```

How It Works

TurboQuant compresses the KV cache after prefill using three techniques:

  1. Randomized Hadamard rotation -- spreads information uniformly across all dimensions so a single codebook works optimally for every component
  2. Lloyd-Max optimal codebook -- minimizes quantization error for the statistical distribution of rotated vector components
  3. QJL residual correction (keys only) -- 1-bit random projection that corrects the exponential error amplification in softmax attention scores

The compressed cache is decoded once into a float16 buffer. During generation, new tokens are scatter-written into a pre-allocated window. Models see normal float16 arrays from update() -- they never know compression happened.
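A toy plain-Swift sketch of step 1 (illustrative only, not the library's implementation): a fast Walsh-Hadamard transform with random sign flips is orthogonal up to scaling, so it preserves vector norms while spreading a spiky vector's energy evenly across all components:

```swift
import Foundation

/// In-place fast Walsh-Hadamard transform, normalized so the overall
/// transform is orthonormal (length must be a power of two).
func hadamard(_ v: inout [Double]) {
    var h = 1
    while h < v.count {
        for i in stride(from: 0, to: v.count, by: h * 2) {
            for j in i..<(i + h) {
                let (a, b) = (v[j], v[j + h])
                v[j] = a + b
                v[j + h] = a - b
            }
        }
        h *= 2
    }
    let scale = 1.0 / sqrt(Double(v.count))
    for i in v.indices { v[i] *= scale }
}

/// Randomized rotation: flip each component's sign with a fixed random
/// pattern, then apply the Hadamard transform. A spike like [1, 0, 0, ...]
/// is spread to components of equal magnitude, which is why one
/// quantization codebook can fit every dimension.
func randomizedHadamard(_ v: [Double], signs: [Double]) -> [Double] {
    var out = zip(v, signs).map(*)
    hadamard(&out)
    return out
}

let signs: [Double] = [1, -1, 1, 1, -1, 1, -1, -1]
var spike = [Double](repeating: 0, count: 8)
spike[0] = 1
let rotated = randomizedHadamard(spike, signs: signs)
// Every component now has magnitude 1/sqrt(8); the vector's norm is unchanged.
```

Because the rotation is orthonormal, it can be undone exactly at decode time; only the quantization step loses information.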

Memory Savings

| Model | Context | Float Cache | TurboQuant-3 | Savings |
|---|---|---|---|---|
| Gemma 4 26B MoE | 2K | 84 MB | 17 MB | 4.9x |
| Qwen 3.5-35B | 32K | 655 MB | 135 MB | 4.9x |
| Mistral Small 4 (119B) | 2K | 1,208 MB | 244 MB | 4.9x |

Tested Configurations

| Model | Architecture | Format | Modes | Result |
|---|---|---|---|---|
| Gemma 4 26B | MoE (128 experts) | MLX 4-bit | LLM, VLM, multi-turn | Identical on short, near-identical on long |
| Gemma 4 31B | Dense | MLX 4-bit | LLM, multi-turn | Identical on short, near-identical on long |
| Gemma 4 31B | Dense | JANG 4M | LLM | Identical |
| NemotronH 30B-A3B | Hybrid SSM/attention | JANG 4M | LLM, multi-turn | Identical |
| NemotronH 30B-A3B | Hybrid SSM/attention | JANG 2L | LLM | Near-identical |

TurboQuant automatically skips cache layers that are not plain KV caches (MambaCache for SSM layers, RotatingKVCache for sliding-window attention). If maxKVSize is set, every layer uses RotatingKVCache, and TurboQuant gracefully becomes a no-op.
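The per-layer skip logic can be sketched like this (stand-in stub types mirroring the cache kinds named above, not the real MLX Swift API):

```swift
// Stub cache kinds modeled on the names used in the text.
protocol LayerCache {}
struct KVCacheSimple: LayerCache {}    // standard attention KV cache
struct MambaCache: LayerCache {}       // SSM state -- not a KV cache
struct RotatingKVCache: LayerCache {}  // sliding window -- already bounded

/// Only plain KV caches are candidates for TurboQuant compression;
/// everything else is passed through untouched.
func compressibleLayers(_ caches: [LayerCache]) -> [Int] {
    caches.enumerated().compactMap { index, cache in
        cache is KVCacheSimple ? index : nil
    }
}

let caches: [LayerCache] = [
    KVCacheSimple(), MambaCache(), KVCacheSimple(), RotatingKVCache(),
]
print(compressibleLayers(caches))  // [0, 2]
```

This is why hybrid models like NemotronH work out of the box: their attention layers get compressed while their SSM layers keep their native state.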


Supported Models

LLMs (50+ architectures)

Llama, Mistral, Phi, Phi-3, Phi-MoE, Gemma, Gemma 2, Gemma 3, Gemma 3n, Gemma 4, Qwen2, Qwen3, Qwen3-MoE, Qwen3.5, Qwen3.5-MoE, DeepSeek-V3, Cohere, OpenELM, InternLM2, Starcoder2, MiniCPM, Granite, Granite-MoE-Hybrid, MiMo, MiMo-V2-Flash, MiniMax, GLM-4, GLM-4-MoE, Falcon-H1, Bitnet, SmolLM3, ERNIE 4.5, LFM2, LFM2-MoE, Baichuan-M1, Exaone4, GPT-OSS, Lille-130m, OLMoE, OLMo2, OLMo3, Bailing-MoE, NanoChat, Nemotron-H, AF-MoE, Jamba, Mistral Small 4 (MLA + MoE), Mistral3, Apertus, and more.

VLMs (17+ architectures)

PaliGemma, Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3.5, Qwen3.5-MoE, Gemma 3, Gemma 4, SmolVLM2, FastVLM, Pixtral, Mistral Small 4 (MLA + Pixtral), Mistral3, LFM2-VL, GLM-OCR, Idefics3, and more.

Embedders

Sentence Transformers, BERT, and other popular embedding models.


Quick Start

Add the package to your Package.swift:

```swift
.package(url: "https://github.com/osaurus-ai/mlx-swift-lm", branch: "main"),
```

Then add tokenizer and downloader integrations:

```swift
.package(url: "https://github.com/DePasqualeOrg/swift-tokenizers-mlx", from: "0.1.0"),
.package(url: "https://github.com/DePasqualeOrg/swift-hf-api-mlx", from: "0.1.0"),
```

And add the libraries to your target:

```swift
.target(
    name: "YourTargetName",
    dependencies: [
        .product(name: "MLXLLM", package: "mlx-swift-lm"),
        .product(name: "MLXLMTokenizers", package: "swift-tokenizers-mlx"),
        .product(name: "MLXLMHuggingFace", package: "swift-hf-api-mlx"),
    ]),
```

Chat Session

```swift
import MLXLLM
import MLXLMHuggingFace
import MLXLMTokenizers

let model = try await loadModel(
    from: HubClient.default,
    using: TokenizersLoader(),
    id: "mlx-community/Qwen3-4B-4bit"
)
let session = ChatSession(model)
print(try await session.respond(to: "What are two things to see in San Francisco?"))
```

Loading a Local Model

```swift
import MLXLLM
import MLXLMTokenizers

// Works for any model -- standard MLX, JANG, or unquantized
let container = try await loadModelContainer(
    from: URL(filePath: "/path/to/model"),
    using: TokenizersLoader()
)
```

JANG models are detected automatically. No special flags needed.

Checking VLM Support

```swift
let container = try await loadModelContainer(from: modelDirectory, using: TokenizersLoader())

if await container.isVLM {
    // Model supports images -- can pass UserInput with .images
} else {
    // Text-only model
}
```

You can also check before loading, using the model type string from config.json:

```swift
import MLXVLM

// Synchronous -- no actor isolation needed
if VLMTypeRegistry.supportedModelTypes.contains(modelType) {
    // This model_type is a known VLM architecture
}
```

VLM-capable families: Gemma 4, Gemma 3, Qwen 3.5 VL, Qwen 3 VL, Qwen 2.5 VL, Mistral Small 4, Mistral 3, PaliGemma, Pixtral, SmolVLM2, FastVLM, Idefics3, LFM2-VL, GLM-OCR.

Tokenizer and Downloader Integrations

MLX Swift LM focuses on model implementations. Tokenization and downloading are handled by separate packages:

| Downloader | Adapter |
|---|---|
| huggingface/swift-huggingface | DePasqualeOrg/swift-huggingface-mlx |
| DePasqualeOrg/swift-hf-api | DePasqualeOrg/swift-hf-api-mlx |

| Tokenizer | Adapter |
|---|---|
| DePasqualeOrg/swift-tokenizers | DePasqualeOrg/swift-tokenizers-mlx |
| huggingface/swift-transformers | DePasqualeOrg/swift-transformers-mlx |

Note: Adapters are optional. You can set up protocol conformance directly. See the adapter packages for examples.


How JANG Loading Works

  1. Detection -- Factory checks for jang_config.json in the model directory.
  2. Config parsing -- JangLoader reads the JANG profile (bit widths, block size, source model info).
  3. Weight loading -- Standard .safetensors files loaded normally (JANG v2 is MLX-native).
  4. Sanitize -- Model-specific weight key remapping (VLM prefix stripping, expert key normalization).
  5. Gate dequantization -- MoE gate weights restored to bfloat16 for routing precision.
  6. Quantization inference -- Per-layer bit widths inferred from tensor shapes.
  7. Apply -- Inferred per-layer quantization replaces uniform quantization from config.json.

If jang_config.json doesn't exist, the standard MLX loading path runs unchanged.
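A minimal sketch of steps 1 and 6 (stand-in code, not the actual JangLoader): detection is a file-existence check, and per-layer bit widths can be recovered from how many packed 32-bit words a quantized row occupies relative to the original column count:

```swift
import Foundation

/// Step 1 (sketch): a JANG model is any directory containing
/// jang_config.json next to its safetensors weights.
func isJangModel(_ directory: URL) -> Bool {
    FileManager.default.fileExists(
        atPath: directory.appendingPathComponent("jang_config.json").path)
}

/// Step 6 (sketch): MLX packs quantized weights into UInt32 words, so a
/// row of `originalColumns` values at `bits` per value occupies
/// originalColumns * bits / 32 packed columns. Inverting that ratio
/// recovers the per-layer bit width from tensor shapes alone.
func inferredBits(packedColumns: Int, originalColumns: Int) -> Int {
    packedColumns * 32 / originalColumns
}

// A 4096-wide layer packed into 512 UInt32 columns was quantized at
// 4 bits; one packed into 256 columns was quantized at 2 bits.
print(inferredBits(packedColumns: 512, originalColumns: 4096))  // 4
print(inferredBits(packedColumns: 256, originalColumns: 4096))  // 2
```

Inferring bits from shapes is what lets one loader handle mixed-precision checkpoints without listing every layer's width in config.json.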


Migrating from Upstream

Change your package URL:

```swift
// Before
.package(url: "https://github.com/ml-explore/mlx-swift-lm", branch: "main"),

// After
.package(url: "https://github.com/osaurus-ai/mlx-swift-lm", branch: "main"),
```

Everything else stays the same. You gain JANG support, Gemma 4, Mistral Small 4, speculative decoding, isVLM, and MoE performance boosts for free.

If migrating from upstream 2.x, see the version 3 migration guide below.

Migrating to Version 3

Version 3 decouples tokenizer and downloader implementations.

New dependencies

```swift
// Before (2.x)
.package(url: "https://github.com/ml-explore/mlx-swift-lm/", from: "2.30.0"),

// After (3.x)
.package(url: "https://github.com/osaurus-ai/mlx-swift-lm/", branch: "main"),
.package(url: "https://github.com/DePasqualeOrg/swift-tokenizers-mlx/", from: "0.1.0"),
.package(url: "https://github.com/DePasqualeOrg/swift-hf-api-mlx/", from: "0.1.0"),
```

New imports

```swift
// Before (2.x)
import MLXLLM

// After (3.x)
import MLXLLM
import MLXLMHuggingFace  // Downloader adapter
import MLXLMTokenizers   // Tokenizer adapter
```

API changes

  • hub: parameter is now from: (accepts any Downloader or local URL)
  • HubApi is now HubClient
  • decode(tokens:) is renamed to decode(tokenIds:)

```swift
// Before (2.x)
let container = try await loadModelContainer(id: "mlx-community/Qwen3-4B-4bit")

// After (3.x)
let container = try await loadModelContainer(
    from: HubClient.default,
    id: "mlx-community/Qwen3-4B-4bit"
)
```

Documentation


Files Changed vs. Upstream

| File | Change | Purpose |
|---|---|---|
| MLXLLM/Models/Gemma4Text.swift | New | Gemma 4 text (MoE + Dense, dual attention, v_norm, K=V) |
| MLXVLM/Models/Gemma4.swift | New | Gemma 4 VLM (vision encoder, 2D RoPE, pooler, processor) |
| MLXLLM/Models/Mistral4.swift | New | Mistral Small 4 (MLA attention, 128-expert MoE, YaRN RoPE) |
| MLXVLM/Models/Mistral4VLM.swift | New | Mistral Small 4 VLM (MLA text + Pixtral vision) |
| MLXLMCommon/JangLoader.swift | New | JANG detection, config, per-layer quant, gate dequant |
| MLXLMCommon/Load.swift | Modified | JANG pipeline, VLM key remap, bfloat16 MoE conversion |
| MLXLMCommon/SwitchLayers.swift | Modified | Compiled SwiGLU/GeGLU activation kernels |
| MLXLMCommon/LanguageModel.swift | Modified | VisionLanguageModelProtocol for isVLM |
| MLXLMCommon/ModelFactory.swift | Modified | ModelContext.isVLM |
| MLXLMCommon/ModelContainer.swift | Modified | ModelContainer.isVLM |
| MLXLMCommon/Tool/ToolCallFormat.swift | Modified | Gemma 4, Gemma 3, MiniMax tool call formats |
| MLXLLM/LLMModelFactory.swift | Modified | gemma4, mistral4 registrations |
| MLXVLM/VLMModelFactory.swift | Modified | gemma4/mistral4 VLM + processor dispatch |
| MLXLLM/Models/NemotronH.swift | Modified | JANG key remap for Nemotron MoE |
| MLXVLM/Models/Qwen35.swift | Modified | JANG VLM sanitize fix |

Roadmap

  • Native TurboQuant -- Quantization-aware weight format for faster loading
  • Paged KV Cache -- Memory-efficient caching for long contexts
  • Prefix Caching -- Reuse KV cache across prompts with shared prefixes
  • Async L2 Disk Cache -- Spill KV cache to disk for very long contexts

Known Limitations

  • Raw HuggingFace checkpoints -- JANG and mlx-community pre-converted models are supported. Raw HF transformers checkpoints (with fused gate_up_proj) require conversion first.
  • Audio -- Gemma 4 supports audio natively, but the audio encoder is not yet implemented.
  • Gemma 4 2B/4B -- Per-layer input gating and KV sharing for the smaller variants are not yet implemented.
  • Speculative decoding + RotatingKVCache -- Speculative decoding requires trimmable caches, so it is not compatible once a RotatingKVCache has wrapped.

License

MIT License. See LICENSE for details.

Based on mlx-swift-lm by Apple's ML Explore team.
