mlx_lm Integration
VeloxQuant-MLX is designed to work as a drop-in extension for mlx_lm. This guide covers the three integration patterns: the KVCacheBuilder helper, the mlx_lm_patch monkey-patch, and the fused SDPA kernel.
Pattern 1 — KVCacheBuilder (recommended)
KVCacheBuilder is the primary integration point. It inspects the model's config and constructs one KVCache per transformer layer, matching num_key_value_heads and head_dim automatically.
import mlx_lm
from veloxquant_mlx.cache.base import KVCacheConfig, KVCacheBuilder
model, tokenizer = mlx_lm.load("mlx-community/Llama-3.2-3B-Instruct-4bit")
config = KVCacheConfig(method="turboquant_rvq", bits=1, value_bits=2)
cache = KVCacheBuilder.build(model, config)
# Pass cache directly to mlx_lm.generate
response = mlx_lm.generate(
model,
tokenizer,
prompt="Hello, world!",
max_tokens=256,
kv_cache=cache,
)
KVCacheBuilder.build() works with any model that exposes model.args.num_hidden_layers, model.args.num_key_value_heads, and model.args.head_dim — which covers all major mlx_lm model families.
Pattern 2 — mlx_lm monkey-patch
The monkey-patch approach automatically intercepts the default KV cache creation inside mlx_lm and replaces it with VeloxQuant-MLX caches. This requires zero changes to your generation call:
import mlx_lm
from veloxquant_mlx.integration.mlx_lm_patch import patch_mlx_lm
model, tokenizer = mlx_lm.load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
config = KVCacheConfig(method="turboquant_rvq", bits=1)
patch_mlx_lm(model, config) # patches model.make_cache()
# No kv_cache argument needed — the patch intercepts it
response = mlx_lm.generate(model, tokenizer, prompt="...", max_tokens=512)
The monkey-patch is useful when integrating with third-party code that calls mlx_lm.generate directly and does not expose a kv_cache argument.
Pattern 3 — Fused SDPA
patch_mlx_lm_for_fused_sdpa replaces the attention computation with a Metal kernel that dequantizes keys/values and computes attention in a single GPU pass:
from veloxquant_mlx.metal.fused_sdpa import patch_mlx_lm_for_fused_sdpa
# Call once after loading the model
patch_mlx_lm_for_fused_sdpa(model)
# Subsequent generate calls use the fused kernel automatically
response = mlx_lm.generate(model, tokenizer, prompt="...", max_tokens=1024, kv_cache=cache)
Fused SDPA is most beneficial when:
- The KV cache is large (long sequences, many layers)
- You are using VecInfer (the fused kernel is optimised for its codebook format)
- Throughput is the priority over latency on individual calls
Check compatibility before patching:
from veloxquant_mlx.metal.fused_sdpa import supports_shape
# Verify your model's attention shape is supported
is_supported = supports_shape(
batch=1,
heads=model.args.num_attention_heads,
seq_len=4096,
head_dim=model.args.head_dim,
)
print(f"Fused SDPA supported: {is_supported}")
Streaming generation
VeloxQuant-MLX caches work transparently with mlx_lm's streaming API:
for token in mlx_lm.stream_generate(
model, tokenizer,
prompt="Tell me a very long story.",
max_tokens=4096,
kv_cache=cache,
):
print(token, end="", flush=True)
Multi-turn conversations
For multi-turn chat, reuse the same cache across turns. The cache grows across turns but retains compression:
config = KVCacheConfig(method="turboquant_rvq", bits=1)
cache = KVCacheBuilder.build(model, config)
turns = [
"What is the capital of France?",
"And what is it known for?",
"What's the population?",
]
for turn in turns:
response = mlx_lm.generate(
model, tokenizer, prompt=turn, max_tokens=200, kv_cache=cache
)
print(f"User: {turn}\nAssistant: {response}\n")
# cache now contains compressed K/V for all prior turns
Cache capacity is bounded by max_seq_len. If the conversation exceeds this, use SlidingWindowKVCache to evict old tokens.
Supported models
All mlx_lm model families have been validated:
| Model family | Recommended config |
|---|---|
| Llama 3.1 / 3.2 / 3.3 | method="turboquant_rvq", bits=1 |
| Mistral 7B / Mixtral | method="vecinfer", bits=2 |
| Qwen 2.5 (7–72B) | method="spectral", signal_bits=4 |
| Phi-3 / Phi-3.5 Mini | method="commvq", bits=2 |
| Gemma 2B / 7B | method="turboquant_rvq", bits=2 |
| Falcon 7B | method="ratequant", target_bits=2.0 |