Metal Kernels API
veloxquant_mlx.metal
All Metal kernels are compiled lazily on first call via mx.fast.metal_kernel. These are low-level functions — most users should interact with them indirectly through quantizer and cache classes.
:::warning Apple Silicon only
All functions in this module require macOS on an M-series chip. On unsupported hardware they raise MetalUnavailableError.
:::
Availability check
from veloxquant_mlx.metal import metal_available
if not metal_available():
    raise RuntimeError("Metal not available on this device")
VecInfer kernels
veloxquant_mlx.metal._vecinfer
vecinfer_quantize_metal
def vecinfer_quantize_metal(
    keys: mx.array,
    codebook: mx.array,
    smooth_factors: mx.array,
    num_subspaces: int,
) -> mx.array
Product VQ encoding on GPU. Returns integer indices of shape [batch, heads, seq, num_subspaces]. 13× faster than equivalent Python ops.
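To make the encoding step concrete, here is a minimal pure-Python reference of product VQ for a single key vector. It is a sketch and not the Metal kernel: the helper name pvq_encode_reference, the codebook layout [num_subspaces][num_centroids][sub_dim], and the choice to divide the key by the smooth factors before the search are all assumptions made for illustration.

```python
def pvq_encode_reference(key, codebook, smooth_factors):
    """Return one centroid index per subspace for a single key vector.

    Assumed layout (not confirmed by the library):
      key:            list of head_dim floats
      codebook:       [num_subspaces][num_centroids][sub_dim]
      smooth_factors: list of head_dim floats
    """
    num_subspaces = len(codebook)
    sub_dim = len(key) // num_subspaces
    # Assumption: smoothing is an element-wise division applied before encoding.
    smoothed = [k / s for k, s in zip(key, smooth_factors)]

    indices = []
    for m in range(num_subspaces):
        chunk = smoothed[m * sub_dim:(m + 1) * sub_dim]
        # Nearest centroid by squared Euclidean distance within this subspace.
        best = min(
            range(len(codebook[m])),
            key=lambda c: sum((x - y) ** 2 for x, y in zip(chunk, codebook[m][c])),
        )
        indices.append(best)
    return indices


codebook = [
    [[0.0, 0.0], [1.0, 1.0]],    # centroids of subspace 0
    [[-1.0, 0.0], [0.0, 2.0]],   # centroids of subspace 1
]
key = [1.8, 2.2, -2.2, 0.2]
smooth = [2.0, 2.0, 2.0, 2.0]
indices = pvq_encode_reference(key, codebook, smooth)
print(indices)  # [1, 0]
```

The GPU version would run the same per-subspace search for every [batch, heads, seq] position in parallel, which is where the num_subspaces-wide index output comes from.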
vecinfer_dequant_metal
def vecinfer_dequant_metal(
    indices: mx.array,
    codebook: mx.array,
    smooth_factors: mx.array,
) -> mx.array
Codebook gather + smooth-factor inverse. Returns reconstructed keys of shape [batch, heads, seq, head_dim].
vecinfer_encode_decode_metal
def vecinfer_encode_decode_metal(
    keys: mx.array,
    codebook: mx.array,
    smooth_factors: mx.array,
    num_subspaces: int,
) -> tuple[mx.array, mx.array]
Fused encode then decode in one kernel dispatch. Returns (indices, reconstructed_keys).
compute_query_lut
from veloxquant_mlx.allocators.vecinfer import compute_query_lut
def compute_query_lut(
    queries: mx.array,
    codebook: mx.array,
    smooth_factors: mx.array,
) -> mx.array
Precomputes a query-codebook distance look-up table for asymmetric MIPS (Maximum Inner Product Search). Returns [batch, heads, num_subspaces, num_centroids].
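The look-up-table idea can be illustrated with a short pure-Python sketch for one query. The function names query_lut_reference and lut_score are hypothetical, the table entries are assumed to be inner products (the natural choice for MIPS), and smooth factors are left out because this page does not say how they are folded in.

```python
def query_lut_reference(query, codebook):
    """lut[m][c] = inner product of query chunk m with centroid c of subspace m.

    Assumed codebook layout: [num_subspaces][num_centroids][sub_dim].
    """
    num_subspaces = len(codebook)
    sub_dim = len(query) // num_subspaces
    lut = []
    for m in range(num_subspaces):
        chunk = query[m * sub_dim:(m + 1) * sub_dim]
        lut.append([sum(q * c for q, c in zip(chunk, centroid))
                    for centroid in codebook[m]])
    return lut


def lut_score(lut, indices):
    """Approximate <query, key> by summing one table entry per subspace."""
    return sum(lut[m][idx] for m, idx in enumerate(indices))


codebook = [
    [[0.0, 0.0], [1.0, 1.0]],    # centroids of subspace 0
    [[-1.0, 0.0], [0.0, 2.0]],   # centroids of subspace 1
]
query = [1.0, 2.0, 3.0, 4.0]
lut = query_lut_reference(query, codebook)
score = lut_score(lut, [1, 1])   # key encoded as centroid 1 in both subspaces
print(lut)    # [[0.0, 3.0], [-3.0, 8.0]]
print(score)  # 11.0
```

The score matches the exact inner product of the query with the reconstructed key [1, 1, 0, 2], yet each key costs only num_subspaces table reads. That is the "asymmetric" part: the query stays in full precision while the keys stay encoded.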
RaBitQ kernels
veloxquant_mlx.metal._rabitq
rabitq_hamming_score
def rabitq_hamming_score(
    query_bits: mx.array,
    key_bits: mx.array,
    scale: float,
) -> mx.array
XOR + popcount Hamming distance as inner product proxy. Uses Metal native instructions.
- query_bits: packed uint32, shape [batch, heads, 1, words]
- key_bits: packed uint32, shape [batch, heads, seq, words]
- Returns: [batch, heads, 1, seq] attention score approximation
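As a rough illustration of why Hamming distance can stand in for an inner product, the sketch below scores one packed query against one packed key in pure Python. The helper name hamming_score_reference, the bit-to-sign mapping (1 means +1, 0 means -1), the explicit num_bits argument, and the way scale is applied are assumptions, not the kernel's documented behaviour.

```python
def hamming_score_reference(query_words, key_words, num_bits, scale):
    """Score one query against one key, both packed as lists of uint32 words.

    Reading bits as signs (1 -> +1, 0 -> -1), the inner product of the two
    sign vectors equals num_bits - 2 * hamming_distance.
    """
    # XOR marks the positions where the signs differ; counting the 1s is popcount.
    hamming = sum(bin(q ^ k).count("1") for q, k in zip(query_words, key_words))
    return scale * (num_bits - 2 * hamming)


# Signs differ in exactly one of four positions: 3 agreements - 1 disagreement = 2.
score = hamming_score_reference([0b1011], [0b1010], num_bits=4, scale=0.5)
print(score)  # 1.0
```

Unused padding bits are zero in both operands, so they cancel in the XOR and do not affect the count.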
CommVQ kernels
veloxquant_mlx.metal._comm_vq
comm_vq_decode_metal
def comm_vq_decode_metal(
    indices: mx.array,
    codebook: mx.array,
    cos_freqs: mx.array,
    sin_freqs: mx.array,
    positions: mx.array,
) -> mx.array
Fused centroid gather + RoPE application in a single Metal pass. Returns decoded+position-embedded keys.
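The two fused steps can be written out separately in plain Python for a single token. This is only a sketch under stated assumptions: the name decode_with_rope_reference is hypothetical, the codebook is assumed to be laid out as [num_subspaces][num_centroids][sub_dim], cos_freqs and sin_freqs are assumed to be tables indexed as [position][pair], and RoPE is assumed to rotate adjacent (even, odd) pairs. The real kernel may use a different pairing.

```python
import math


def decode_with_rope_reference(indices, codebook, cos_freqs, sin_freqs, position):
    # Step 1, gather: concatenate the selected centroid from each subspace.
    vec = []
    for m, idx in enumerate(indices):
        vec.extend(codebook[m][idx])

    # Step 2, RoPE: rotate each (even, odd) pair by the angle for this position.
    out = []
    for i in range(0, len(vec), 2):
        c = cos_freqs[position][i // 2]
        s = sin_freqs[position][i // 2]
        x0, x1 = vec[i], vec[i + 1]
        out.extend([x0 * c - x1 * s, x0 * s + x1 * c])
    return out


codebook = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[2.0, 0.0], [0.0, 2.0]],
]
angles = [[0.0, 0.0], [math.pi / 2, math.pi]]   # [position][pair]
cos_freqs = [[math.cos(a) for a in row] for row in angles]
sin_freqs = [[math.sin(a) for a in row] for row in angles]

at_pos0 = decode_with_rope_reference([0, 1], codebook, cos_freqs, sin_freqs, 0)
at_pos1 = decode_with_rope_reference([0, 1], codebook, cos_freqs, sin_freqs, 1)
print(at_pos0)  # [1.0, 0.0, 0.0, 2.0]
```

Fusing both steps avoids writing the gathered, unrotated keys to memory before the rotation reads them back.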
Scalar quantization kernels
veloxquant_mlx.metal._scalar_quant
turboquant_scalar_quantize
def turboquant_scalar_quantize(x: mx.array, bits: int) -> mx.array
Lloyd-Max scalar quantization on GPU.
turboquant_scalar_dequantize
def turboquant_scalar_dequantize(indices: mx.array, bits: int, scale: float) -> mx.array
turboquant_hadamard_quantize
def turboquant_hadamard_quantize(x: mx.array, bits: int) -> tuple[mx.array, mx.array]
Fused WHT rotation + scalar quantization in one pass. Returns (indices, scale_factors).
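For readers unfamiliar with the rotation step, here is a small pure-Python sketch of a Walsh-Hadamard transform (WHT) followed by scalar quantization. The function names are hypothetical. The quantizer shown is a plain uniform grid with one scale per vector, used here only as a simple stand-in: the page does not specify which quantizer the fused kernel applies after the rotation.

```python
import math


def fwht(x):
    """Orthonormal fast Walsh-Hadamard transform; len(x) must be a power of two."""
    y = list(x)
    n = len(y)
    h = 1
    while h < n:
        # Butterfly step: combine elements that are h apart.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = y[i], y[i + h]
                y[i], y[i + h] = a + b, a - b
        h *= 2
    norm = math.sqrt(n)
    return [v / norm for v in y]


def hadamard_quantize_reference(x, bits):
    rotated = fwht(x)
    levels = 2 ** bits
    # Assumption: a single symmetric scale per vector and a uniform grid.
    scale = max(abs(v) for v in rotated) or 1.0
    step = 2 * scale / (levels - 1)
    indices = [round((v + scale) / step) for v in rotated]
    return indices, scale


rotated = fwht([1.0, 1.0, 1.0, 1.0])
print(rotated)  # [2.0, 0.0, 0.0, 0.0]
indices, scale = hadamard_quantize_reference([3.0, -1.0, 0.5, 2.0], bits=4)
```

Because the normalized transform is its own inverse, applying fwht again after dequantization undoes the rotation.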
RVQ + Attention fusion
veloxquant_mlx.metal._rvq_attend
turboquant_fused_rvq_decode_attend
def turboquant_fused_rvq_decode_attend(
    queries: mx.array,
    encoded_keys: EncodedVector,
    values: mx.array,
    scale: float,
) -> mx.array
Two-stage RVQ decode + scaled dot-product attention in a single kernel. Most efficient path for TurboQuant RVQ inference.
Fused SDPA
veloxquant_mlx.metal.fused_sdpa
metal_fused_sdpa
from veloxquant_mlx.metal.fused_sdpa import metal_fused_sdpa
def metal_fused_sdpa(
    queries: mx.array,
    encoded_keys: EncodedVector,
    values: mx.array,
    scale: float,
    mask: mx.array | None = None,
) -> mx.array
Fused dequantize + scaled dot-product attention. Supports all VeloxQuant-MLX key formats.
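For reference, the attention half of this computation can be sketched in plain Python for a single query. The helper sdpa_reference is hypothetical and takes keys that are already dequantized; treating the mask as additive (0 keeps a position, -inf drops it) is an assumption about the convention, not something this page states.

```python
import math


def sdpa_reference(query, keys, values, scale, mask=None):
    """Attention output for one query over seq keys and values (lists of floats)."""
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    if mask is not None:
        # Assumption: additive mask, one entry per key position.
        scores = [s + m for s, m in zip(scores, mask)]
    # Numerically stable softmax: subtract the maximum before exponentiating.
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


query = [1.0, 0.0]
keys = [[1.0, 0.0], [1.0, 0.0]]
values = [[2.0, 0.0], [0.0, 4.0]]
out = sdpa_reference(query, keys, values, scale=0.5)
masked = sdpa_reference(query, keys, values, scale=0.5, mask=[0.0, float("-inf")])
print(out)     # [1.0, 2.0]
print(masked)  # [2.0, 0.0]
```

A fused kernel produces the same result without ever materializing the dequantized keys as a separate array.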
supports_shape
def supports_shape(batch: int, heads: int, seq_len: int, head_dim: int) -> bool
Returns True if the fused kernel supports this attention shape. Requires head_dim to be a multiple of 32.
patch_mlx_lm_for_fused_sdpa
from veloxquant_mlx.metal.fused_sdpa import patch_mlx_lm_for_fused_sdpa
def patch_mlx_lm_for_fused_sdpa(model) -> None
Monkey-patches each attention layer of model to call metal_fused_sdpa in place of the standard mx.matmul-based attention. Call it once, after the model has been loaded.
Bit packing
veloxquant_mlx.metal._bit_packing
turboquant_bit_pack
def turboquant_bit_pack(indices: mx.array, bits: int) -> mx.array
Packs bits-bit indices into uint32 words. Input shape [..., N], output shape [..., ceil(N*bits/32)].
turboquant_bit_unpack
def turboquant_bit_unpack(
    packed: mx.array,
    bits: int,
    original_length: int,
) -> mx.array
Unpacks uint32 words back to int32 indices.
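A pure-Python round trip shows how dense packing arrives at the ceil(N*bits/32) word count, including values that straddle a word boundary. The function names and the little-endian bit order within each word are assumptions chosen for illustration; the kernel's actual bit layout is not documented here.

```python
def bit_pack_reference(indices, bits):
    """Pack each index into `bits` bits, densely, into a list of uint32 words."""
    num_words = (len(indices) * bits + 31) // 32   # ceil(N * bits / 32)
    words = [0] * num_words
    for n, value in enumerate(indices):
        word, offset = divmod(n * bits, 32)
        words[word] |= (value << offset) & 0xFFFFFFFF
        # Spill the high bits into the next word when the value straddles a boundary.
        if offset + bits > 32:
            words[word + 1] |= value >> (32 - offset)
    return words


def bit_unpack_reference(words, bits, original_length):
    """Recover the first original_length indices from the packed words."""
    mask = (1 << bits) - 1
    out = []
    for n in range(original_length):
        word, offset = divmod(n * bits, 32)
        value = words[word] >> offset
        if offset + bits > 32:
            value |= words[word + 1] << (32 - offset)
        out.append(value & mask)
    return out


indices = [5, 0, 7, 3, 1, 6, 2, 4, 7, 7, 6, 5]   # 12 values * 3 bits = 36 bits
packed = bit_pack_reference(indices, bits=3)
unpacked = bit_unpack_reference(packed, bits=3, original_length=len(indices))
print(len(packed))          # 2
print(unpacked == indices)  # True
```

The final word here carries only 4 meaningful bits, which is why the unpack step needs the original length: the padding alone cannot say how many indices were stored.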
QJL kernels
veloxquant_mlx.metal._qjl
qjl_encode
def qjl_encode(keys: mx.array, projection: mx.array) -> mx.array
Project + sign in one Metal pass. Returns packed uint32 bit strings.
qjl_inner_product
def qjl_inner_product(
    query_bits: mx.array,
    key_bits: mx.array,
    head_dim: int,
    sketch_dim: int,
) -> mx.array
Approximates ⟨q, k⟩ via bit string inner product.
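One standard way to relate sign-bit sketches to an inner product is the sign random projection (SimHash) identity: the fraction of differing bits estimates the angle between the two vectors divided by π. The sketch below demonstrates that identity in pure Python. It is not necessarily the estimator this kernel uses; the names sign_sketch and estimate_cosine are hypothetical, and the bits are left unpacked for readability.

```python
import math
import random


def sign_sketch(vec, projection):
    """One bit per projection row: 1 if <row, vec> >= 0, else 0."""
    return [1 if sum(r * v for r, v in zip(row, vec)) >= 0 else 0
            for row in projection]


def estimate_cosine(query_bits, key_bits):
    """Estimate cos(angle) from the fraction of differing sign bits."""
    sketch_dim = len(query_bits)
    hamming = sum(q != k for q, k in zip(query_bits, key_bits))
    return math.cos(math.pi * hamming / sketch_dim)


rng = random.Random(0)
head_dim, sketch_dim = 8, 256
# Gaussian projection matrix of shape [sketch_dim][head_dim].
projection = [[rng.gauss(0.0, 1.0) for _ in range(head_dim)]
              for _ in range(sketch_dim)]
q = [rng.gauss(0.0, 1.0) for _ in range(head_dim)]

q_bits = sign_sketch(q, projection)
same = estimate_cosine(q_bits, sign_sketch(q, projection))
opposite = estimate_cosine(q_bits, sign_sketch([-v for v in q], projection))
print(same, opposite)  # 1.0 -1.0
```

Multiplying the estimated cosine by the two vector norms turns it into an inner product estimate, and a larger sketch_dim tightens the estimate at the cost of more bits per key.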