
Metal Kernels API

veloxquant_mlx.metal

All Metal kernels are compiled lazily on first call via mx.fast.metal_kernel. These are low-level functions; most users should use them indirectly, through the quantizer and cache classes.

:::warning Apple Silicon only
All functions in this module require macOS on an M-series chip. On unsupported hardware they raise MetalUnavailableError.
:::


Availability check

from veloxquant_mlx.metal import metal_available

if not metal_available():
    raise RuntimeError("Metal not available on this device")

VecInfer kernels

veloxquant_mlx.metal._vecinfer

vecinfer_quantize_metal

def vecinfer_quantize_metal(
    keys: mx.array,
    codebook: mx.array,
    smooth_factors: mx.array,
    num_subspaces: int,
) -> mx.array

Product vector quantization (product VQ) encoding on the GPU. Returns integer indices of shape [batch, heads, seq, num_subspaces], one codebook index per subspace. 13× faster than the equivalent Python ops.
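To make the encoding step concrete, here is a pure-Python reference sketch for a single key vector. It is not the kernel's code, and the helper name pq_encode is hypothetical. Several details are assumptions rather than documented behavior: the codebook is indexed as codebook[subspace][centroid] with sub-vectors of length head_dim // num_subspaces, smoothing divides each channel by its smooth factor, and the nearest centroid is chosen by squared Euclidean distance.

```python
# Hypothetical reference sketch of product-VQ encoding for ONE key vector.
# Assumed layout (not confirmed by the API docs):
#   codebook[m][c]  -> centroid c of subspace m, a list of head_dim // num_subspaces floats
#   smooth_factors  -> one divisor per channel, applied before the nearest-centroid search


def pq_encode(key, codebook, smooth_factors, num_subspaces):
    """Return one centroid index per subspace for a single key vector."""
    head_dim = len(key)
    if head_dim % num_subspaces != 0:
        raise ValueError("head_dim must be divisible by num_subspaces")
    sub_dim = head_dim // num_subspaces

    # Assumption: smoothing divides each channel by its smooth factor.
    smoothed = [x / s for x, s in zip(key, smooth_factors)]

    indices = []
    for m in range(num_subspaces):
        sub = smoothed[m * sub_dim:(m + 1) * sub_dim]
        # Nearest centroid by squared Euclidean distance.
        dists = [sum((a - b) ** 2 for a, b in zip(sub, centroid)) for centroid in codebook[m]]
        indices.append(dists.index(min(dists)))
    return indices


# Two subspaces, two 2-D centroids each.
codebook = [
    [[0.0, 0.0], [1.0, 1.0]],
    [[0.0, 1.0], [1.0, 0.0]],
]
print(pq_encode([0.9, 1.1, 0.1, 0.8], codebook, [1.0, 1.0, 1.0, 1.0], 2))  # [1, 0]
```

The GPU kernel performs this search for every (batch, head, position, subspace) in parallel, which is where the speedup over per-op Python code comes from.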


vecinfer_dequant_metal

def vecinfer_dequant_metal(
    indices: mx.array,
    codebook: mx.array,
    smooth_factors: mx.array,
) -> mx.array

Gathers the codebook centroid for each index and undoes the smooth-factor scaling. Returns reconstructed keys of shape [batch, heads, seq, head_dim].
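The decode direction can be sketched the same way for one key. This is a hypothetical pure-Python helper (pq_decode), not the kernel, and it assumes the same layout as an encoder that divides each channel by its smooth factor, so the inverse multiplies the factor back in.

```python
# Hypothetical reference sketch of product-VQ decoding for ONE key vector.
# Assumes codebook[m][c] is centroid c of subspace m and that encoding divided each
# channel by its smooth factor, so decoding multiplies it back.


def pq_decode(indices, codebook, smooth_factors):
    """Gather one centroid per subspace, concatenate, then undo the smoothing."""
    smoothed = []
    for m, c in enumerate(indices):
        smoothed.extend(codebook[m][c])  # codebook gather
    return [x * s for x, s in zip(smoothed, smooth_factors)]  # smooth-factor inverse


codebook = [
    [[0.0, 0.0], [1.0, 1.0]],
    [[0.0, 1.0], [1.0, 0.0]],
]
print(pq_decode([1, 0], codebook, [2.0, 2.0, 1.0, 1.0]))  # [2.0, 2.0, 0.0, 1.0]
```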


vecinfer_encode_decode_metal

def vecinfer_encode_decode_metal(
    keys: mx.array,
    codebook: mx.array,
    smooth_factors: mx.array,
    num_subspaces: int,
) -> tuple[mx.array, mx.array]

Encodes and then decodes in a single fused kernel dispatch. Returns (indices, reconstructed_keys).


compute_query_lut

from veloxquant_mlx.allocators.vecinfer import compute_query_lut

def compute_query_lut(
    queries: mx.array,
    codebook: mx.array,
    smooth_factors: mx.array,
) -> mx.array

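Precomputes a look-up table of query-centroid inner products for asymmetric MIPS (Maximum Inner Product Search), so that each encoded key can be scored with one table lookup per subspace instead of a full dot product. Returns shape [batch, heads, num_subspaces, num_centroids].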
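The reason a LUT works is that a product-quantized key is a concatenation of centroids, so its inner product with a query splits into one small inner product per subspace. The sketch below assumes decoded keys are k = s ⊙ centroid (per channel), which lets the smooth factors be folded into the query once: ⟨q, k⟩ = Σ_m ⟨q_m ⊙ s_m, c_m⟩. The helpers query_lut and lut_score are hypothetical, single-query, pure-Python stand-ins.

```python
# Hypothetical reference sketch of an asymmetric inner-product LUT for ONE query vector.
# Assumes decoded keys are k = s * centroid (per channel), so
#   <q, k> = sum over subspaces m of <q_m * s_m, centroid_m>.


def query_lut(query, codebook, smooth_factors):
    """lut[m][c] = inner product of smoothed query slice m with centroid c of subspace m."""
    num_subspaces = len(codebook)
    sub_dim = len(query) // num_subspaces
    # Fold the smooth factors into the query once instead of into every key.
    q = [x * s for x, s in zip(query, smooth_factors)]
    lut = []
    for m in range(num_subspaces):
        q_sub = q[m * sub_dim:(m + 1) * sub_dim]
        lut.append([sum(a * b for a, b in zip(q_sub, centroid)) for centroid in codebook[m]])
    return lut


def lut_score(lut, indices):
    """Score an encoded key with one table lookup per subspace."""
    return sum(lut[m][c] for m, c in enumerate(indices))


codebook = [
    [[0.0, 0.0], [1.0, 1.0]],
    [[0.0, 1.0], [1.0, 0.0]],
]
lut = query_lut([1.0, 2.0, 3.0, 4.0], codebook, [2.0, 2.0, 1.0, 1.0])
print(lut)                     # [[0.0, 6.0], [4.0, 3.0]]
print(lut_score(lut, [1, 0]))  # 10.0
```

The table costs num_subspaces × num_centroids dot products per query, after which scoring any number of cached keys needs only lookups and additions.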


RaBitQ kernels

veloxquant_mlx.metal._rabitq

rabitq_hamming_score

def rabitq_hamming_score(
    query_bits: mx.array,
    key_bits: mx.array,
    scale: float,
) -> mx.array

Computes the Hamming distance between packed bit strings with XOR + popcount and uses it as a proxy for the inner product. Uses Metal's native bit instructions.

  • query_bits: packed uint32, shape [batch, heads, 1, words]
  • key_bits: packed uint32, shape [batch, heads, seq, words]
  • Returns: [batch, heads, 1, seq] attention score approximation

CommVQ kernels

veloxquant_mlx.metal._comm_vq

comm_vq_decode_metal

def comm_vq_decode_metal(
    indices: mx.array,
    codebook: mx.array,
    cos_freqs: mx.array,
    sin_freqs: mx.array,
    positions: mx.array,
) -> mx.array

Fuses the centroid gather and RoPE (rotary position embedding) application into a single Metal pass. Returns the decoded keys with position embeddings applied.
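The two steps being fused can be sketched for one key as follows. This is a hypothetical pure-Python reference (rope_tables and decode_with_rope are illustrative names), and it makes assumptions the docs do not state: a single codebook indexed by one integer per key, cos/sin tables of shape [max_pos][head_dim // 2], and the interleaved-pair RoPE convention where channels (2i, 2i+1) are rotated together. Some implementations rotate the two halves of the vector instead.

```python
import math

# Hypothetical reference: centroid gather followed by rotary position embedding (RoPE).
# Assumptions: one codebook indexed by a single integer per key, cos/sin tables of shape
# [max_pos][head_dim // 2], and the interleaved-pair convention (x[2i], x[2i + 1]).


def rope_tables(max_pos, head_dim, base=10000.0):
    """Precompute cos/sin of angle pos * base**(-2i / head_dim) for every position and pair."""
    inv_freq = [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]
    cos_t = [[math.cos(p * f) for f in inv_freq] for p in range(max_pos)]
    sin_t = [[math.sin(p * f) for f in inv_freq] for p in range(max_pos)]
    return cos_t, sin_t


def decode_with_rope(index, codebook, cos_t, sin_t, pos):
    """Gather a centroid, then rotate each (even, odd) channel pair by its position angle."""
    x = codebook[index]
    out = []
    for i in range(len(x) // 2):
        a, b = x[2 * i], x[2 * i + 1]
        c, s = cos_t[pos][i], sin_t[pos][i]
        out.extend([a * c - b * s, a * s + b * c])
    return out


cos_t, sin_t = rope_tables(max_pos=16, head_dim=4)
codebook = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0]]
print(decode_with_rope(0, codebook, cos_t, sin_t, pos=0))  # [1.0, 2.0, 3.0, 4.0]
```

Doing the rotation in the same pass as the gather means the un-rotated centroids never have to be written out and read back, which is the usual motivation for this kind of fusion.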


Scalar quantization kernels

veloxquant_mlx.metal._scalar_quant

turboquant_scalar_quantize

def turboquant_scalar_quantize(x: mx.array, bits: int) -> mx.array

Lloyd-Max scalar quantization on GPU.
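Lloyd-Max quantization chooses reconstruction levels that minimize mean squared error, alternating two conditions: each decision threshold sits halfway between neighboring levels, and each level is the mean of the values falling in its cell. The sketch below runs that iteration (Lloyd's algorithm) on samples. Whether the kernel fits levels at runtime or uses precomputed tables is not stated here; lloyd_max and quantize are hypothetical illustration helpers.

```python
import bisect
import random

# Hypothetical reference: Lloyd's algorithm (the iterative form of Lloyd-Max design)
# fitted to samples. Illustration only; not the kernel's implementation.


def lloyd_max(samples, bits, iters=50):
    """Return 2**bits sorted reconstruction levels that locally minimise mean squared error."""
    n_levels = 1 << bits
    xs = sorted(samples)
    # Start from evenly spaced sample quantiles.
    levels = [xs[int((i + 0.5) * len(xs) / n_levels)] for i in range(n_levels)]
    for _ in range(iters):
        # Decision thresholds sit halfway between neighbouring levels.
        thresholds = [(a + b) / 2 for a, b in zip(levels, levels[1:])]
        buckets = [[] for _ in range(n_levels)]
        for x in xs:
            buckets[bisect.bisect_right(thresholds, x)].append(x)
        # Each level moves to the mean of its bucket (empty buckets keep their level).
        levels = [sum(b) / len(b) if b else lv for b, lv in zip(buckets, levels)]
    return levels


def quantize(x, levels):
    """Index of the nearest level, found via the midpoint thresholds."""
    thresholds = [(a + b) / 2 for a, b in zip(levels, levels[1:])]
    return bisect.bisect_right(thresholds, x)


rng = random.Random(0)
samples = [rng.gauss(0.0, 1.0) for _ in range(4000)]
levels = lloyd_max(samples, bits=1)
print(levels)  # close to [-0.80, 0.80], the 1-bit Lloyd-Max levels for N(0, 1)
```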

turboquant_scalar_dequantize

def turboquant_scalar_dequantize(indices: mx.array, bits: int, scale: float) -> mx.array

turboquant_hadamard_quantize

def turboquant_hadamard_quantize(x: mx.array, bits: int) -> tuple[mx.array, mx.array]

Fused Walsh-Hadamard transform (WHT) rotation + scalar quantization in one pass. Returns (indices, scale_factors).
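The rotation step can be illustrated with an orthonormal fast Walsh-Hadamard transform, which spreads an outlier coordinate across all coordinates before quantization and is its own inverse. In this hypothetical sketch, a uniform symmetric grid with a per-vector absmax scale stands in for the Lloyd-Max levels used by the scalar kernel; fwht, hadamard_quantize and hadamard_dequantize are illustrative names, and the scale scheme is an assumption.

```python
import math

# Hypothetical reference: orthonormal fast Walsh-Hadamard transform (WHT) followed by
# scalar quantization. A uniform symmetric grid with a per-vector absmax scale stands in
# for the Lloyd-Max levels used by the scalar kernel.


def fwht(x):
    """Orthonormal WHT; len(x) must be a power of two. Applying it twice returns x."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    y = list(x)
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = y[j], y[j + h]
                y[j], y[j + h] = a + b, a - b  # butterfly
        h *= 2
    norm = 1.0 / math.sqrt(n)
    return [v * norm for v in y]


def hadamard_quantize(x, bits):
    """Rotate, then map each coordinate to an integer in [0, 2**bits - 1]."""
    y = fwht(x)
    scale = max(abs(v) for v in y) or 1.0  # per-vector scale (assumed scheme)
    top = (1 << bits) - 1
    indices = [round((v / scale + 1) / 2 * top) for v in y]
    return indices, scale


def hadamard_dequantize(indices, scale, bits):
    top = (1 << bits) - 1
    y = [(i / top * 2 - 1) * scale for i in indices]
    return fwht(y)  # the orthonormal WHT is its own inverse


# A single outlier is spread evenly over all coordinates by the rotation.
print(fwht([8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]))  # eight values of about 2.83
```

Since the rotation is orthonormal, the reconstruction error after the inverse transform has the same L2 norm as the quantization error in the rotated domain.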


RVQ + Attention fusion

veloxquant_mlx.metal._rvq_attend

turboquant_fused_rvq_decode_attend

def turboquant_fused_rvq_decode_attend(
    queries: mx.array,
    encoded_keys: EncodedVector,
    values: mx.array,
    scale: float,
) -> mx.array

Decodes two-stage residual VQ (RVQ) keys and computes scaled dot-product attention in a single kernel. This is the most efficient path for TurboQuant RVQ inference.
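What the fused kernel computes can be written out for a single query as an unfused reference. The sketch assumes each encoded key is a pair (i1, i2) reconstructed as stage1[i1] + stage2[i2]; the real EncodedVector layout is not described here, and rvq_decode / rvq_decode_attend are hypothetical helpers.

```python
import math

# Hypothetical reference for ONE query: two-stage residual VQ decode followed by
# scaled dot-product attention. Assumes each encoded key is a pair (i1, i2) and the key
# is reconstructed as stage1[i1] + stage2[i2]; the real EncodedVector layout may differ.


def rvq_decode(codes, stage1, stage2):
    """Second-stage centroid refines (is added to) the first-stage centroid."""
    i1, i2 = codes
    return [a + b for a, b in zip(stage1[i1], stage2[i2])]


def rvq_decode_attend(query, encoded_keys, values, stage1, stage2, scale):
    keys = [rvq_decode(c, stage1, stage2) for c in encoded_keys]
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    # Numerically stable softmax over the sequence.
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    # Weighted sum of the value vectors.
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]


stage1 = [[1.0, 0.0], [0.0, 1.0]]
stage2 = [[0.0, 0.0], [0.5, 0.5]]
encoded = [(0, 0), (1, 1)]  # keys [1.0, 0.0] and [0.5, 1.5]
values = [[1.0, 0.0], [0.0, 1.0]]
print(rvq_decode_attend([1.0, 0.0], encoded, values, stage1, stage2, scale=1.0))
```

Fusing both stages into one kernel lets the decoded keys stay in registers or threadgroup memory instead of being materialized as a full tensor, which is the usual reason such fusion is faster.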


Fused SDPA

veloxquant_mlx.metal.fused_sdpa

metal_fused_sdpa

from veloxquant_mlx.metal.fused_sdpa import metal_fused_sdpa

def metal_fused_sdpa(
    queries: mx.array,
    encoded_keys: EncodedVector,
    values: mx.array,
    scale: float,
    mask: mx.array | None = None,
) -> mx.array

Fused dequantize + scaled dot-product attention. Supports all VeloxQuant-MLX key formats.

supports_shape

def supports_shape(batch: int, heads: int, seq_len: int, head_dim: int) -> bool

Returns True if the fused kernel supports this attention shape. Requires head_dim to be a multiple of 32.

patch_mlx_lm_for_fused_sdpa

from veloxquant_mlx.metal.fused_sdpa import patch_mlx_lm_for_fused_sdpa

def patch_mlx_lm_for_fused_sdpa(model) -> None

Monkey-patches each attention layer to use metal_fused_sdpa instead of the standard mx.matmul path. Call it once, after the model is loaded.


Bit packing

veloxquant_mlx.metal._bit_packing

turboquant_bit_pack

def turboquant_bit_pack(indices: mx.array, bits: int) -> mx.array

Packs indices that are each bits wide into uint32 words. Input shape [..., N], output shape [..., ceil(N*bits/32)].

turboquant_bit_unpack

def turboquant_bit_unpack(
    packed: mx.array,
    bits: int,
    original_length: int,
) -> mx.array

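Unpacks uint32 words back to int32 indices. original_length is the number of indices N before packing; it is needed because the last word may contain padding bits.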
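The ceil(N*bits/32) output size implies dense packing, where a value may straddle two words. The pure-Python sketch below shows one such layout and its inverse. The LSB-first bit order is an assumption, and bit_pack / bit_unpack are hypothetical helpers rather than the kernel functions.

```python
# Hypothetical reference for dense bit packing: values are laid out LSB-first and may
# straddle a 32-bit word boundary, which gives ceil(N * bits / 32) words.
# The real kernel's bit order is an assumption here.


def bit_pack(indices, bits):
    mask = (1 << bits) - 1
    words = [0] * ((len(indices) * bits + 31) // 32)
    for i, v in enumerate(indices):
        if not 0 <= v <= mask:
            raise ValueError(f"value {v} does not fit in {bits} bits")
        w, off = divmod(i * bits, 32)
        words[w] |= (v << off) & 0xFFFFFFFF  # low part goes into word w
        spill = off + bits - 32
        if spill > 0:  # high part spills into word w + 1
            words[w + 1] |= v >> (bits - spill)
    return words


def bit_unpack(words, bits, original_length):
    mask = (1 << bits) - 1
    out = []
    for i in range(original_length):  # original_length skips trailing padding bits
        w, off = divmod(i * bits, 32)
        v = words[w] >> off
        spill = off + bits - 32
        if spill > 0:
            v |= words[w + 1] << (bits - spill)
        out.append(v & mask)
    return out


idx = [i % 8 for i in range(32)]
packed = bit_pack(idx, bits=3)
print(len(packed))                             # 3, i.e. ceil(32 * 3 / 32)
print(bit_unpack(packed, 3, len(idx)) == idx)  # True
```

With bits=5 and N=5, for example, only 25 of 32 bits are used, and the word alone could be read as six values; original_length removes that ambiguity.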


QJL kernels

veloxquant_mlx.metal._qjl

qjl_encode

def qjl_encode(keys: mx.array, projection: mx.array) -> mx.array

Projects the keys with projection and keeps only the sign of each projected coordinate, in one Metal pass. Returns the sign bits packed into uint32 bit strings.
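The project-then-sign step is a one-bit random sketch of the key. The reference below assumes projection has shape [sketch_dim][head_dim], sets a bit to 1 for a non-negative projection, and packs LSB-first; all of these, and the helper names gaussian_projection and qjl_encode_ref, are hypothetical.

```python
import random

# Hypothetical reference for "project + sign": multiply the key by a projection matrix
# (assumed shape [sketch_dim][head_dim]), keep only the sign of each output, and pack the
# sign bits LSB-first into 32-bit words (1 = non-negative).


def gaussian_projection(sketch_dim, head_dim, seed=0):
    """Random Gaussian projection matrix, seeded for reproducibility."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(head_dim)] for _ in range(sketch_dim)]


def qjl_encode_ref(key, projection):
    words = [0] * ((len(projection) + 31) // 32)
    for r, row in enumerate(projection):
        if sum(p * x for p, x in zip(row, key)) >= 0:
            words[r // 32] |= 1 << (r % 32)
    return words


proj = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
print(qjl_encode_ref([2.0, -3.0], proj))  # [9]: projections [2, -3, -2, 3] -> bits 1, 0, 0, 1
```

Only the direction of the key survives this encoding: scaling the key by a positive constant leaves every bit unchanged.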

qjl_inner_product

def qjl_inner_product(
    query_bits: mx.array,
    key_bits: mx.array,
    head_dim: int,
    sketch_dim: int,
) -> mx.array

Approximates ⟨q, k⟩ via bit string inner product.
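One standard way to recover an inner product from two sign sketches is the sign-random-projection (SimHash) identity: for a Gaussian direction r, P[sign(r·q) ≠ sign(r·k)] = θ/π, where θ is the angle between q and k. The sketch below uses that estimator as a stand-in; it is not necessarily the kernel's estimator, and it assumes the vector norms are available separately. How head_dim and sketch_dim enter the kernel's normalization is not specified here, and all helper names are hypothetical.

```python
import math
import random

# Hypothetical reference: estimate <q, k> from two sign sketches using the identity
# P[sign(r.q) != sign(r.k)] = angle(q, k) / pi for Gaussian r.
# Norms are assumed to be stored separately; this is a stand-in, not the kernel's estimator.


def gaussian_projection(sketch_dim, head_dim, seed=0):
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(head_dim)] for _ in range(sketch_dim)]


def sign_bits(x, projection):
    return [1 if sum(p * v for p, v in zip(row, x)) >= 0 else 0 for row in projection]


def estimate_inner_product(q_bits, k_bits, q_norm, k_norm):
    sketch_dim = len(q_bits)
    hamming = sum(a != b for a, b in zip(q_bits, k_bits))
    angle = math.pi * hamming / sketch_dim  # fraction of differing signs estimates angle / pi
    return q_norm * k_norm * math.cos(angle)


q = [1.0, 2.0, 0.0, -1.0, 0.5, 3.0, -2.0, 1.0]
k = [2.0, 1.0, -1.0, 0.0, 1.0, 2.0, -1.0, 0.0]
P = gaussian_projection(sketch_dim=2048, head_dim=len(q))
est = estimate_inner_product(sign_bits(q, P), sign_bits(k, P),
                             math.sqrt(sum(v * v for v in q)), math.sqrt(sum(v * v for v in k)))
print(est, sum(a * b for a, b in zip(q, k)))  # estimate vs. exact value 12.5
```

The estimate sharpens as sketch_dim grows, since the fraction of differing bits concentrates around θ/π with a standard deviation that shrinks like 1/√sketch_dim.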


See also