Observers API

veloxquant_mlx.observers


DistortionObserver

from veloxquant_mlx.observers.distortion import DistortionObserver

Measures cosine similarity and inner product estimation error between original and quantized keys.

Constructor

DistortionObserver(sample_rate: float = 1.0)

| Parameter | Type | Default | Description |
|---|---|---|---|
| sample_rate | float | 1.0 | Fraction of tokens to measure (1.0 = all). Lower values reduce overhead. |

Methods

def attach(self, cache: KVCache | list[KVCache]) -> None
def report(self) -> DistortionReport
def reset(self) -> None

attach(cache) — Registers the observer with one or more cache instances. Hooks into the encode/decode cycle.

report() — Returns a DistortionReport after generation completes.

reset() — Clears accumulated statistics. Call between runs.

DistortionReport

| Field | Type | Description |
|---|---|---|
| mean_cosine_similarity | float | Average cosine sim across all tokens and layers |
| min_cosine_similarity | float | Worst-case cosine sim |
| mean_ip_error | float | Mean absolute inner product estimation error |
| per_layer_cosine_similarity | dict[str, float] | Per-layer breakdown |
| worst_layer | str | Layer with lowest cosine sim |
| total_tokens_measured | int | Total tokens included in statistics |
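
The sketch below shows one way the cosine similarity and inner product fields could be derived from pairs of original and dequantized key vectors. It is a pure-Python illustration under stated assumptions, not the observer's actual implementation: the helper names (cosine_similarity, ip_error, summarize), the use of an explicit query vector for the inner product, and the toy data are all hypothetical.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def ip_error(query, key, key_hat):
    """Absolute error of <query, key> when key is replaced by key_hat.

    Assumption: the inner product is taken against a query vector.
    """
    exact = sum(q * k for q, k in zip(query, key))
    approx = sum(q * k for q, k in zip(query, key_hat))
    return abs(exact - approx)

def summarize(per_layer_pairs):
    """per_layer_pairs maps a layer name to a list of (key, key_hat) pairs."""
    per_layer = {}
    all_sims = []
    for layer, pairs in per_layer_pairs.items():
        sims = [cosine_similarity(k, k_hat) for k, k_hat in pairs]
        per_layer[layer] = sum(sims) / len(sims)
        all_sims.extend(sims)
    return {
        "mean_cosine_similarity": sum(all_sims) / len(all_sims),
        "min_cosine_similarity": min(all_sims),
        "per_layer_cosine_similarity": per_layer,
        "worst_layer": min(per_layer, key=per_layer.get),
        "total_tokens_measured": len(all_sims),
    }

# Toy data: layer_0 keeps direction exactly, layer_1 is rotated by 45 degrees.
toy = {
    "layer_0": [([1.0, 0.0], [1.0, 0.0]), ([0.0, 2.0], [0.0, 1.0])],
    "layer_1": [([1.0, 0.0], [1.0, 1.0])],
}
summary = summarize(toy)
print(summary["worst_layer"], round(summary["min_cosine_similarity"], 4))
```

Note that the second pair in layer_0 has a cosine similarity of 1.0 even though the reconstructed key has half the norm; this is why an inner product error metric is a useful complement to cosine similarity.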

LatencyObserver

from veloxquant_mlx.observers.latency import LatencyObserver

Profiles per-call encode and decode latency.

Constructor

LatencyObserver(warmup_calls: int = 2)

| Parameter | Type | Default | Description |
|---|---|---|---|
| warmup_calls | int | 2 | Calls to skip before recording (exclude Metal JIT warmup) |
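
To illustrate what skipping warmup calls means in practice, here is a minimal sketch of a timer that discards its first few measurements. WarmupTimer and its timed method are hypothetical names invented for this example; the real observer hooks into the cache rather than wrapping functions.

```python
import time

class WarmupTimer:
    """Hypothetical helper: times calls but discards the first `warmup_calls`."""

    def __init__(self, warmup_calls=2):
        self.warmup_calls = warmup_calls
        self.calls_seen = 0
        self.samples_ms = []

    def timed(self, fn, *args, **kwargs):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        self.calls_seen += 1
        # Only record once the warmup budget has been used up.
        if self.calls_seen > self.warmup_calls:
            self.samples_ms.append(elapsed_ms)
        return result

timer = WarmupTimer(warmup_calls=2)
for _ in range(5):
    timer.timed(sum, range(1000))

print(f"calls seen: {timer.calls_seen}, recorded: {len(timer.samples_ms)}")
```

Because MLX evaluates lazily, a timer around real MLX work would presumably also need to force evaluation (for example with mx.eval) before stopping the clock; the sketch above times plain Python only.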

Methods

def attach(self, cache: KVCache | list[KVCache]) -> None
def report(self) -> LatencyReport
def reset(self) -> None

LatencyReport

| Field | Type | Description |
|---|---|---|
| mean_encode_ms | float | Average milliseconds per encode call |
| mean_decode_ms | float | Average milliseconds per decode call |
| p99_encode_ms | float | 99th percentile encode latency |
| total_encode_ms | float | Cumulative encode time |
| total_decode_ms | float | Cumulative decode time |
| per_layer_encode_ms | dict[str, float] | Per-layer average encode time |
| slowest_layer | str | Layer with highest total encode time |
| num_encode_calls | int | Total encode calls recorded |
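
The following sketch shows how the encode-side fields relate to a set of raw per-call timings. It assumes a nearest-rank definition of the 99th percentile, which the documentation does not specify; the function names and sample timings are hypothetical.

```python
import math

def nearest_rank_percentile(samples, pct):
    """Nearest-rank percentile: smallest sample with at least pct% of data at or below it."""
    ordered = sorted(samples)
    rank = math.ceil(pct / 100.0 * len(ordered))
    return ordered[max(rank, 1) - 1]

def summarize_latency(per_layer_samples_ms):
    """per_layer_samples_ms maps a layer name to its recorded encode times in ms."""
    all_samples = [s for samples in per_layer_samples_ms.values() for s in samples]
    totals = {layer: sum(s) for layer, s in per_layer_samples_ms.items()}
    return {
        "mean_encode_ms": sum(all_samples) / len(all_samples),
        "p99_encode_ms": nearest_rank_percentile(all_samples, 99),
        "total_encode_ms": sum(all_samples),
        "per_layer_encode_ms": {
            layer: sum(s) / len(s) for layer, s in per_layer_samples_ms.items()
        },
        # Ranked by total time, matching the slowest_layer description above.
        "slowest_layer": max(totals, key=totals.get),
        "num_encode_calls": len(all_samples),
    }

# Made-up timings, for illustration only.
latency = summarize_latency({"layer_0": [1.0, 1.0, 2.0], "layer_1": [5.0]})
print(latency["mean_encode_ms"], latency["p99_encode_ms"], latency["slowest_layer"])
```

With only a handful of samples, the 99th percentile collapses to the maximum, so p99_encode_ms is most informative on longer generations.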

MemoryObserver

from veloxquant_mlx.observers.memory import MemoryObserver

Tracks peak cache memory and computes the compression ratio against an fp16 baseline.

Constructor

MemoryObserver()

Methods

def attach(self, cache: KVCache | list[KVCache]) -> None
def report(self) -> MemoryReport
def reset(self) -> None

MemoryReport

| Field | Type | Description |
|---|---|---|
| peak_compressed_mb | float | Peak compressed cache memory in MB |
| peak_fp16_mb | float | Equivalent fp16 cache memory in MB |
| compression_ratio | float | peak_fp16_mb / peak_compressed_mb |
| total_tokens | int | Total tokens written to cache |
| bytes_per_token | float | Average bytes per token across all layers |
| per_layer_mb | dict[str, float] | Per-layer peak memory |
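
As a worked example of the compression_ratio formula, the sketch below computes an fp16 baseline from a model shape and divides it by a compressed footprint. Everything here is an assumption made for illustration: that the baseline counts both keys and values, that "MB" means 2**20 bytes, the model shape, and the compressed size of 14 MB, which is not a measurement.

```python
BYTES_PER_FP16 = 2
MB = 1024 * 1024  # assumption: "MB" means 2**20 bytes

def fp16_kv_bytes_per_token(num_layers, num_kv_heads, head_dim):
    """Bytes to store one token's keys and values in fp16, summed over layers."""
    tensors = 2  # assumption: both K and V are counted
    return tensors * num_layers * num_kv_heads * head_dim * BYTES_PER_FP16

# Illustrative shape and token count, not taken from any measurement.
bytes_per_token = fp16_kv_bytes_per_token(num_layers=28, num_kv_heads=8, head_dim=128)
total_tokens = 1024
peak_fp16_mb = bytes_per_token * total_tokens / MB

# Hypothetical compressed footprint, chosen only to keep the arithmetic simple.
peak_compressed_mb = 14.0
compression_ratio = peak_fp16_mb / peak_compressed_mb

print(f"fp16 baseline: {peak_fp16_mb:.1f} MB")
print(f"compression  : {compression_ratio:.1f}x")
```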

KeyNormObserver

from veloxquant_mlx.observers.key_norm import KeyNormObserver

Monitors key vector norms and detects outlier tokens.

Constructor

KeyNormObserver(
    outlier_threshold: float = 3.0,
    window_size: int = 128,
)

| Parameter | Type | Default | Description |
|---|---|---|---|
| outlier_threshold | float | 3.0 | Norms above mean + threshold × std are outliers |
| window_size | int | 128 | Rolling window size for computing running statistics |

Methods

def attach(self, cache: KVCache | list[KVCache]) -> None
def report(self) -> KeyNormReport
def reset(self) -> None

KeyNormReport

| Field | Type | Description |
|---|---|---|
| mean_key_norm | float | Rolling mean of key L2 norms |
| std_key_norm | float | Rolling std of key norms |
| max_key_norm | float | Maximum norm seen |
| outlier_count | int | Total tokens flagged as outliers |
| outlier_fraction | float | outlier_count / total_tokens |
| mean_outlier_norm | float | Average norm of outlier tokens |
| per_layer_outlier_count | dict[str, int] | Outliers per layer |
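
The outlier rule (norm above mean + threshold × std over a rolling window) can be sketched in a few lines. This is an illustration only: find_outliers is a hypothetical function, and it assumes that each norm is compared against statistics of the preceding window (before the norm itself is added) and that the population standard deviation is used. The documentation does not specify either detail.

```python
import math
from collections import deque

def find_outliers(norms, outlier_threshold=3.0, window_size=128):
    """Return indices of norms exceeding mean + threshold * std of the preceding window."""
    window = deque(maxlen=window_size)
    outliers = []
    for index, norm in enumerate(norms):
        if len(window) >= 2:  # need at least two samples for a meaningful std
            mean = sum(window) / len(window)
            variance = sum((x - mean) ** 2 for x in window) / len(window)
            if norm > mean + outlier_threshold * math.sqrt(variance):
                outliers.append(index)
        window.append(norm)  # the deque drops the oldest value once full
    return outliers

# Made-up key norms with a single spike at index 6.
norms = [1.0, 1.1, 0.9, 1.0, 1.1, 0.9, 5.0, 1.0]
outlier_indices = find_outliers(norms)
outlier_fraction = len(outlier_indices) / len(norms)
print(outlier_indices, f"{outlier_fraction:.1%}")
```

A smaller window_size adapts faster to drifting norms but makes the statistics noisier, so the threshold and window are best tuned together.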

Example — all observers together

import mlx_lm
from veloxquant_mlx.cache.base import KVCacheConfig, KVCacheBuilder
from veloxquant_mlx.observers.distortion import DistortionObserver
from veloxquant_mlx.observers.memory import MemoryObserver
from veloxquant_mlx.observers.latency import LatencyObserver
from veloxquant_mlx.observers.key_norm import KeyNormObserver

model, tokenizer = mlx_lm.load("mlx-community/Llama-3.2-3B-Instruct-4bit")
config = KVCacheConfig(method="turboquant_rvq", bits=1)
cache = KVCacheBuilder.build(model, config)

observers = [
    DistortionObserver(),
    MemoryObserver(),
    LatencyObserver(),
    KeyNormObserver(outlier_threshold=3.0),
]
# Attach every observer to the same cache before generation starts.
for obs in observers:
    obs.attach(cache)

mlx_lm.generate(model, tokenizer, prompt="...", max_tokens=1024, kv_cache=cache)

dist, mem, lat, norm = [obs.report() for obs in observers]
print(f"Cosine sim : {dist.mean_cosine_similarity:.4f}")
print(f"Compression: {mem.compression_ratio:.1f}×")
print(f"Encode lat : {lat.mean_encode_ms:.2f} ms")
print(f"Outliers : {norm.outlier_fraction:.1%}")

See also