
# Quantization Deep Dive

## Overview

We explored multiple quantization strategies for LeWM. This document records what was tried, why it worked or didn't, and the engineering trade-offs.

## Formats Evaluated

| Format | Bits/weight | Compression | Cos vs f32 | Status |
|---|---|---|---|---|
| INT8 (encoder) | 8 | 4x | 0.9999 | Production |
| Q4 (predictor) | 4 | 2x more | 0.998 | Production |
| INT8+Q4 (full) | mixed | 5x total | 0.999 | Production |
| Q4 encoder only | 4 | 4x | 0.93 | Rejected |
| Ternary {-1,0,+1} | 2 | 8x | ~0.85 | Rejected |
| WANDA 20% sparse | Q4 | 20% fewer | ~0.99 | Experimental |
| WANDA 40% sparse | Q4 | 40% fewer | ~0.97 | Experimental |

## INT8 Per-Channel (Encoder)

### Method

```
f32_weight [out, in] → per-output-channel quantization
                     → int8_weight [out, in]
                     → per-output-channel scales [out]
```

At inference:

```
result = input @ dequant(weight)
       = input @ ((int8_weight - zero_point) × scales)
       = (input @ int8_weight) × scales   // zero_point = 0 for symmetric
```
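The per-channel scheme above can be sketched in a few lines of NumPy. This is a reference sketch of the arithmetic, not the production PIE SIMD kernel; the shapes and names are illustrative.

```python
import numpy as np

def quantize_int8_per_channel(w):
    """Symmetric per-output-channel INT8 quantization of a [out, in] weight."""
    absmax = np.abs(w).max(axis=1, keepdims=True)        # [out, 1]
    scales = np.where(absmax == 0, 1.0, absmax / 127.0)  # one scale per channel
    q = np.clip(np.round(w / scales), -127, 127).astype(np.int8)
    return q, scales.astype(np.float32)

def int8_gemv(x, q, scales):
    # zero_point = 0 (symmetric), so the per-channel scale factors out of the dot.
    return (q.astype(np.int32) @ x) * scales.ravel()

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 128)).astype(np.float32)
x = rng.normal(size=128).astype(np.float32)

q, scales = quantize_int8_per_channel(w)
y_ref = w @ x
y_q = int8_gemv(x, q, scales)
cos = float(y_q @ y_ref / (np.linalg.norm(y_q) * np.linalg.norm(y_ref)))
```

Because `zero_point = 0`, the integer GEMV runs entirely in int8/int32 and the float scale is applied once per output channel at the end.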

### Why It Works

  1. Per-channel preserves channel-level statistics. Each output channel gets its own scale.
  2. Symmetric (zero_point = 0) avoids the overhead of asymmetric quantization.
  3. INT8 GEMV on PIE SIMD: 16-wide multiply-accumulate per cycle.
  4. Encoder activations have predictable dynamic range after LayerNorm.

### Quality

| Layer | Cos vs f32 |
|---|---|
| patch_embed | 0.9999 |
| encoder layer 0 | 0.9999 |
| encoder layer 5 | 0.9998 |
| encoder.proj | 0.9999 |
| Total | 0.9999 |

### Engineering Notes

- Activation quantization is dynamic (per-row at inference time), not static.
- QKV shared quantization: the same normalized input is quantized once and reused for Q, K, and V.
- Scales are stored as f32 (4 bytes per channel); the overhead is negligible.

## Q4 Block (Predictor)

### Method

```
f32_weight [out, in] → per-32-element-block quantization
                     → nibble-packed weight data
                     → per-block f16 scales
```

At inference:

```
for each block of 32:
    unpack nibbles → int8 [-8, 7]
    dot = simd_dot(input_block, unpacked)
    result += dot × block_scale
```
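The pack/unpack-and-accumulate loop can be prototyped in NumPy. This is a reference sketch of the same arithmetic, not the SIMD kernel:

```python
import numpy as np

BLOCK = 32

def quantize_q4(w):
    """Per-32-element-block symmetric 4-bit quantization with f16 block scales."""
    blocks = w.reshape(-1, BLOCK)
    absmax = np.abs(blocks).max(axis=1, keepdims=True)
    scales = np.where(absmax == 0, 1.0, absmax / 7.0).astype(np.float16)
    q = np.clip(np.round(blocks / scales.astype(np.float32)), -8, 7).astype(np.int8)
    nibbles = (q + 8).astype(np.uint8)                   # shift to unsigned 0..15
    packed = nibbles[:, 0::2] | (nibbles[:, 1::2] << 4)  # two weights per byte
    return packed, scales

def q4_dot(x, packed, scales):
    lo = (packed & 0x0F).astype(np.int8) - 8             # unpack nibbles -> [-8, 7]
    hi = (packed >> 4).astype(np.int8) - 8
    q = np.empty((packed.shape[0], BLOCK), dtype=np.int8)
    q[:, 0::2], q[:, 1::2] = lo, hi
    xb = x.reshape(-1, BLOCK)
    # per-block dot products, each scaled by its f16 block scale
    return float(((q * xb).sum(axis=1) * scales.astype(np.float32).ravel()).sum())

rng = np.random.default_rng(2)
w = rng.normal(size=256).astype(np.float32)
x = rng.normal(size=256).astype(np.float32)
packed, scales = quantize_q4(w)
approx, exact = q4_dot(x, packed, scales), float(w @ x)
```

Storage per 32-element block: 16 bytes of nibbles plus a 2-byte f16 scale, i.e. 4.5 bits/weight.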

### Why It Works

  1. Per-block (32 elements) matches the predictor's adaLN normalization.
  2. adaLN modulation provides implicit normalization, so weights don't need per-channel precision.
  3. Fewer layers mean less error accumulation: the predictor has 4 layers vs the encoder's 6.
  4. PIE SIMD handles the nibble unpack + dot in tight loops.

### Why Full Q4 (Encoder + Predictor) Fails

When we skip INT8 and quantize the encoder to Q4:

| Layer | INT8 cos | Q4 cos |
|---|---|---|
| encoder | 0.9999 | 0.93 |
| predictor | 0.998 | 0.998 |
| Total | 0.999 | 0.93 |

Root cause: ViT encoder has high dynamic range in intermediate activations. The 32-element block granularity doesn't align with the encoder's channel statistics. INT8's per-channel precision is essential.
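A toy example of the failure mode: one outlier in a 32-element block sets the block scale, and at 4 bits every small weight in that block collapses to zero, while 8 bits still resolve them. (The values here are made up for illustration.)

```python
import numpy as np

# One outlier among 31 small weights: the block scale is set by the outlier.
block = np.array([0.01] * 31 + [1.0], dtype=np.float32)

s4 = np.abs(block).max() / 7.0                      # Q4 block scale
q4 = np.clip(np.round(block / s4), -8, 7) * s4      # 4-bit reconstruction

s8 = np.abs(block).max() / 127.0                    # INT8 scale over the same values
q8 = np.clip(np.round(block / s8), -127, 127) * s8  # 8-bit reconstruction
```

With 16 levels, `0.01 / s4 ≈ 0.07` rounds to zero; with 256 levels the small weights survive. High-dynamic-range encoder activations amplify exactly this effect.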

### Engineering Notes

- Nibbles decode as `value - 8`, giving the range [-8, 7].
- Block scales are stored as f16 (2 bytes per 32-element block), i.e. 64 bytes of scales per 32×32 weight tile.
- Zero weights are rare in Q4 (<1%), so a skip-zero optimization was not implemented.

## Ternary ({-1, 0, +1})

### Method

```
f32_weight → hard threshold
           → +1 if w > +tau
           → -1 if w < -tau
           →  0 otherwise
           → bit-packed ternary
```

At inference:

```
result = input @ ternary_weight
       = sum(sign(w_i) × x_i)   // pure addition/subtraction
```
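The scheme in NumPy (illustrative shapes; the threshold shown is one of those tried below):

```python
import numpy as np

def ternarize(w, tau):
    """Hard-threshold ternarization to {-1, 0, +1}."""
    return np.where(w > tau, 1, np.where(w < -tau, -1, 0)).astype(np.int8)

rng = np.random.default_rng(4)
w = rng.normal(size=(16, 64)).astype(np.float32)
x = rng.normal(size=64).astype(np.float32)

t = ternarize(w, tau=0.5 * w.std())  # tau = 0.5 sigma
y = t.astype(np.float32) @ x         # pure add/subtract of input elements
y_ref = w @ x
cos = float((y @ y_ref) / (np.linalg.norm(y) * np.linalg.norm(y_ref)))
```

The matmul degenerates to signed accumulation, which is why the compression and speed are so attractive; the quality cost is what kills it here.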

### Why It Fails

| Metric | Q4 | Ternary |
|---|---|---|
| Cos vs f32 | 0.998 | ~0.85 |
| Compression | 8x vs f32 | 16x vs f32 |

Root cause: adaLN generates six modulation vectors (scale1, shift1, gate1, scale2, shift2, gate2) that scale, shift, and gate the normalized activations. The magnitudes of these modulation vectors matter, and ternary quantization destroys them.
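For reference, the modulation path has the usual DiT-style adaLN shape. This is an assumed form for illustration only: the `attention` and `mlp` stand-ins below are not the real predictor sublayers.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

attention = lambda h: h     # identity stand-in, not the real sublayer
mlp = lambda h: np.tanh(h)  # tanh stand-in, not the real sublayer

def ada_ln_block(x, mod):
    """mod: [6, d] modulation vectors regressed from the conditioning signal."""
    scale1, shift1, gate1, scale2, shift2, gate2 = mod
    h = x + gate1 * attention(layer_norm(x) * (1 + scale1) + shift1)
    return h + gate2 * mlp(layer_norm(h) * (1 + scale2) + shift2)

rng = np.random.default_rng(5)
x = rng.normal(size=(2, 8))          # 2 tokens, width 8 (toy sizes)
mod = 0.5 * rng.normal(size=(6, 8))  # the six modulation vectors
out = ada_ln_block(x, mod)
```

Every one of the six vectors enters multiplicatively or additively with its real-valued magnitude; collapsing the weights that produce `mod` to {-1, 0, +1} throws exactly that information away.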

What was tried:

- Various thresholds (τ = 0.5σ, σ, 2σ)
- STE (straight-through estimator) during fine-tuning
- Mixed ternary (ternary weights + fp32 scales)

None recovered quality sufficiently.

## WANDA Pruning

### Method

1. Forward pass on a calibration set → collect activations
2. Compute the WANDA score: s(w) = |w| × ||a||
3. Sort scores, prune the bottom N%
4. Fine-tune for 1 epoch
5. Re-quantize remaining weights to Q4
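Steps 1-3 can be sketched as follows (shapes are illustrative; WANDA compares scores within each output row, which is what the sketch does):

```python
import numpy as np

def wanda_prune(w, acts, sparsity):
    """Zero the lowest-scoring weights per output row, score s(w) = |w| * ||a||."""
    a_norm = np.linalg.norm(acts, axis=0)  # ||a||_2 per input feature
    score = np.abs(w) * a_norm             # broadcasts over output rows
    k = int(sparsity * w.shape[1])         # weights to prune in each row
    idx = np.argsort(score, axis=1)[:, :k] # k lowest-scoring columns per row
    mask = np.ones_like(w, dtype=bool)
    np.put_along_axis(mask, idx, False, axis=1)
    return w * mask, mask

rng = np.random.default_rng(6)
w = rng.normal(size=(32, 64)).astype(np.float32)      # one layer's weight
acts = rng.normal(size=(128, 64)).astype(np.float32)  # calibration activations
w_pruned, mask = wanda_prune(w, acts, sparsity=0.20)
```

The surviving weights (`mask == True`) are then re-quantized to Q4, per step 5.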

### Results

| Model | Pruned | Size | Cos | Notes |
|---|---|---|---|---|
| Baseline Q4 | 0% | 23.6 MB | 0.998 | Reference |
| WANDA 20% | 20% | 22.0 MB | ~0.99 | Pruned, no fine-tune |
| WANDA 40% | 40% | 25.1 MB | ~0.97 | Bitmap overhead exceeds savings |

### Engineering Notes

- The 40% pruned model has a larger binary than the 20% model due to bitmap overhead.
- Skip-zero GEMV needs hardware support (not implemented in PIE SIMD).
- Would benefit from fine-tuning after pruning.

## Shared Lessons

  1. Per-channel > per-block for encoders: Encoders have high per-channel variance. INT8's per-channel precision beats Q4's per-block.

  2. Predictors are quantization-friendly: adaLN provides implicit normalization. Predictors can use Q4 with minimal quality loss.

  3. Architecture changes beat quantization: hybrid_ALAL (3.0M params) achieves similar quality to slim_96d (10.2M params) at INT8+Q4. Architecture > precision.

  4. Epoch 1 models have headroom: All current slim/epoch-1 models will improve with longer training. The quality comparisons should be re-run at convergence.

  5. Hardwired is the limit: Q4 weights decompose to shift-add operations. Zero multiplications, zero memory fetches. That's the theoretical floor.