Add A/B comparison mode for P4 fused vs unfused RMSNorm+quantize

- Added --ab-compare flag to run both fused and unfused paths for first 3 layers
- Compares x_normed, gsa values, FP4 data, and GEMM outputs (q_a, kv)
- Added --no-fused-rmsnorm to disable P4 and use unfused path
- This will help diagnose the correctness regression introduced by P4
This commit is contained in:
2026-06-02 18:49:30 +00:00
parent b8bab01a55
commit eb5ef93bf1
3 changed files with 138 additions and 6 deletions

35
NEXT_SESSION.md Normal file
View File

@@ -0,0 +1,35 @@
Here's the summary for next session:
## Summary: What Was Done & What to Investigate
### What was done today
1. **P4 — Fused RMSNorm + NVFP4 quantize**: New CUDA kernel (2 launches vs 6+). Integrated into `single_shot_inference.py` for the attention path. Unit test passes (cos=0.996-0.999 vs unfused). Decode speed improved 0.53→0.43s/token.
2. **P5 — Fused mHC + RMSNorm + NVFP4 quantize**: New CUDA kernel (2 launches vs 7+). Unit test passes on B200 (cos=0.996-0.999). NOT yet integrated into single_shot — needs moe_forward refactoring.
3. **CRITICAL FIX — Indexer deadlock**: `__syncthreads()` inside a strided loop caused deadlock at production context lengths (always deadlocked when `num_valid % 128 != 0`). Fixed with per-thread local top-k + block-level merge. Both copies updated.
### The regression
After P4 integration, model output went from coherent English to garbled nonsense. This is a NEW regression — output was correct last night after P0-P3 + KV work.
### What to investigate (ordered by likelihood)
1. **`run_from_quantized` gsa shape bug (MOST LIKELY)**: The fused kernel produces per-row gsa (shape `(M,)`). `run_from_quantized` passes this to the CuTeDSL NVFP4 GEMM as `global_scale_a`. But the GEMM expects a **single scalar** — it's one scale for the entire A matrix. For M=1 decode, shape `(1,)` may work as a scalar, BUT the gsa VALUE differs from the unfused path's scalar gsa because:
- Unfused: quantize computes gsa from the full (M, N) tensor — single scalar
- Fused: computes gsa per-row, then `[:1].reshape(1)` takes only the first row's gsa
- These differ when rows have different magnitudes
2. **Dequant→requant noise for compressor**: The compressor gets `x_normed` by dequantizing the fused kernel's FP4 output. This introduces ~0.5% quantization error (cos=0.994) that wasn't present before. This noise propagates into compression scores and indexer queries.
3. **A/B test first**: Run with `_use_fused_rmsnorm_quantize=False` to confirm P4 is the cause. If output is correct with unfused, the bug is in the P4 path.
### Key code paths
- **`dsv4/layers/linear.py`**: `run_from_quantized()` — the gsa handling is suspicious
- **`dsv4/ops/gemm_runner.py`**: `run_nvfp4_grouped_gemm()` — how global_scale_a is consumed
- **`dsv4/kernels/gemm/grouped.py`**: CuTeDSL GEMM — global_scale_a is scalar
- **`single_shot_inference.py:855-870`**: P4 integration point
- **`dsv4/kernels/cuda/fused_rmsnorm_quantize.cu`**: the fused kernel
### Quick fix ideas
- Use **scalar gsa** (reduce per-row gsa to a single max) for the GEMM — keeps the fusion but makes gsa compatible
- Or: don't use `run_from_quantized` — just use the fused kernel for the BF16 output (dequant) and let each linear re-quantize with its own scalar gsa (saves rmsnorm launches but not quantize launches)
- Or: fix the CuTeDSL GEMM to support per-row global_scale_a (THIS IS COMPLEX, BUT IF IT IS THE CORRECT WAY OF DOING THINGS. I WANT IT DONE THAT WAY!!!!!)

View File

@@ -120,7 +120,7 @@ FP32 cos/sin cache, forward + inverse, in-place operation.
Total BF16 at 1M context: ~10 GB on 8×B200. Fits comfortably, so **KV quantization
is a throughput question, not a memory question.**
## Why FP4 storage is the right answer for the compressed streams
## Why FP4 storage is the right answer for the compressed streams - THIS IS NOT WHAT WE ENDED UP USING BECAUSE THE COSINE WAS TOO FAR OFF,
Three reasons, in priority order:
@@ -168,6 +168,37 @@ when E7 lands).
- **Recall@k for indexer ≥ 99% vs FP32 oracle** (the bar from the prior
indexer-fix audit). Critical — FP4 must not corrupt top-k ranking.
### THE ABOVE DID NOT WORK... WHY NOT NVFP4 (native Blackwell FP4)?
─────────────────────────────────────
We *really* wanted to use NVFP4 (E2M1 + E4M3 block scales + FP32 global scale)
for compressed KV storage. Blackwell's native FP4→MMA path would have given us
3.5× memory savings and direct tensor-core consumption — the dream pipeline.
We tried. Hard. Three separate approaches:
1. Fused compressor_reduce_quant.cu — single-kernel compress→NVFP4. Bugs in
cross-warp block amax reduction and shared memory corruption (s_scratch
stomping adjacent variables). Best cos=0.703. Dead.
2. Proven two-kernel path (amax_gsa → quantize_from_buffer) using kv_quantize.cu's
compute_amax_gsa_fp32 + quantize_nvfp4_from_fp32. cos=0.995 on random data,
but that's the *quantize/dequant* round-trip in isolation. In the full pipeline,
the 4-bit precision on 448 non-RoPE dimensions accumulated error across 61 layers
of mHC — residual |X| already grows to 300-500, and NVFP4's 16-element block
quantization (4.5 bits effective) added ~0.5% per layer on top of that.
3. FP32 RoPE kernel (rope_fp32 in kv_quantize.cu) to avoid BF16 RoPE intermediate.
Had an indexing bug (cos=0.977 for M>1). Fixed but the real issue was NVFP4,
not RoPE.
The verdict: NVFP4's 4.5 effective bits per element is simply too coarse for
compressed KV values that get summed in attention softmax. FP8_E4M3's 5.3 effective
bits gives cos=0.9997 round-trip (vs NVFP4's 0.995) — that 0.4% difference compounds
fatally across 61 layers.
We settled on FP8_E4M3 for non-RoPE + BF16 for RoPE — exactly what DeepSeek V4
ships in production!!!!!!!! Not because we couldn't build the NVFP4 path (we did, it compiled
and ran), but because the math didn't hold up. Sometimes 4 bits isn't enough.
If Blackwell adds a finer-grained FP4 variant (8-element blocks, 6 effective bits),
revisit this. The kernels exist. The quantize/dequant path is proven. The precision
just isn't there yet for attention-sensitive KV values.
---
# PART 3 — OTHER FUSION WINS, RANKED BY EFFORT/IMPACT

View File

@@ -27,6 +27,8 @@ def parse_args():
p.add_argument('--seed', type=int, default=42)
p.add_argument('--verbose', type=int, default=1)
p.add_argument('--prefill-only', action='store_true')
p.add_argument('--ab-compare', action='store_true', help='A/B compare fused vs unfused P4 for first 3 layers')
p.add_argument('--no-fused-rmsnorm', action='store_true', help='Disable P4 fused RMSNorm+quantize (use unfused path)')
p.add_argument('--warmup-gsa', action='store_true', help='Fix gsa values after first decode step (eliminates amax kernel launches)')
p.add_argument('--profile', action='store_true', help='Profile per-component GPU time using CUDA events')
p.add_argument('--num-gpus', type=int, default=8)
@@ -196,7 +198,8 @@ class CUDAGraphDecoder:
kv_caches[li], positions, token_id,
compressors.get(li), indexers.get(li),
moe_runners.get(li), se_runners.get(li), routers.get(li),
prod_lin=prod_lins.get(li)
prod_lin=prod_lins.get(li),
_use_fused_rmsnorm_quantize=True
)
# Copy output to fixed buffer
self.x_out_bufs[li].copy_(X_out)
@@ -852,11 +855,70 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
compressor=None, indexer=None,
moe_runner=None, se_runner=None, router=None,
prod_lin=None, _profile_detail=False, _profile_times=None,
_use_fused_rmsnorm_quantize=True):
_use_fused_rmsnorm_quantize=True,
_ab_compare=False):
"""Forward one transformer layer.
_ab_compare: if True, run BOTH fused and unfused paths for this layer
and print detailed numerical comparison. Only use for first few layers.
"""
# P4: Fused RMSNorm + NVFP4 quantize — eliminates ~488 launches/token
from dsv4.ops.quantize import rmsnorm_quantize_nvfp4, QuantizedActivation, dequantize_nvfp4
x_in, ctx_a = attn_mhc.pre_block(X_l)
if _use_fused_rmsnorm_quantize:
# A/B comparison mode: run BOTH paths, compare intermediate results
if _ab_compare and _use_fused_rmsnorm_quantize:
# --- FUSED PATH ---
x_quant_fused = rmsnorm_quantize_nvfp4(x_in, attn_norm_w.to(x_in.device, torch.float32))
x_normed_fused = dequantize_nvfp4(x_quant_fused.x_fp4, x_quant_fused.x_sf, x_quant_fused.gsa)
# --- UNFUSED PATH ---
x_normed_unfused = rmsnorm(x_in, attn_norm_w)
# Quantize unfused x_normed the normal way (as run() would)
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
x_fp4_unf, x_sf_unf, gsa_unf = quantize_nvfp4_gpu_fused(x_normed_unfused)
# Compare x_normed
cos_normed = torch.nn.functional.cosine_similarity(
x_normed_fused.flatten().float(), x_normed_unfused.flatten().float(), dim=0).item()
print(f" L{li} A/B: x_normed |fused|={x_normed_fused.abs().max().item():.6f} "
f"|unfused|={x_normed_unfused.abs().max().item():.6f} cos={cos_normed:.6f}", flush=True)
# Compare gsa values
gsa_fused_val = x_quant_fused.gsa[0].item()
gsa_unfused_val = gsa_unf[0].item()
print(f" L{li} A/B: gsa fused={gsa_fused_val:.8f} unfused={gsa_unfused_val:.8f} "
f"ratio={gsa_fused_val/max(gsa_unfused_val,1e-12):.8f}", flush=True)
# Compare FP4 data (should be different due to different intermediate precision)
fp4_match = torch.equal(x_quant_fused.x_fp4.view(torch.uint8), x_fp4_unf.view(torch.uint8))
sf_match = torch.equal(x_quant_fused.x_sf.view(torch.uint8), x_sf_unf.view(torch.uint8))
print(f" L{li} A/B: fp4_identical={fp4_match} sf_identical={sf_match}", flush=True)
# Compare block scales
sf_diff = (x_quant_fused.x_sf.view(torch.uint8).float() - x_sf_unf.view(torch.uint8).float()).abs()
print(f" L{li} A/B: sf_diff max={sf_diff.max().item():.0f} mean={sf_diff.mean().item():.2f}", flush=True)
# Run BOTH GEMM paths and compare q_a output
q_a_fused = prod_lin['q_a'].run_from_quantized(x_quant_fused)
q_a_unfused = prod_lin['q_a'](x_normed_unfused)
cos_qa = torch.nn.functional.cosine_similarity(
q_a_fused.flatten().float(), q_a_unfused.flatten().float(), dim=0).item()
print(f" L{li} A/B: q_a |fused|={q_a_fused.abs().max().item():.6f} "
f"|unfused|={q_a_unfused.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True)
# Run BOTH GEMM paths for kv
kv_fused = prod_lin['kv'].run_from_quantized(x_quant_fused)
kv_unfused = prod_lin['kv'](x_normed_unfused)
cos_kv = torch.nn.functional.cosine_similarity(
kv_fused.flatten().float(), kv_unfused.flatten().float(), dim=0).item()
print(f" L{li} A/B: kv |fused|={kv_fused.abs().max().item():.6f} "
f"|unfused|={kv_unfused.abs().max().item():.6f} cos={cos_kv:.6f}", flush=True)
# Now continue with the UNFUSED path (which we know works)
# to see if the rest of the layer also diverges
x_normed = x_normed_unfused
x_quant_attn = None
elif _use_fused_rmsnorm_quantize:
x_quant_attn = rmsnorm_quantize_nvfp4(x_in, attn_norm_w.to(x_in.device, torch.float32))
# Dequantize for compressor/indexer (1 kernel launch)
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
@@ -1311,7 +1373,9 @@ def main():
kv_caches[li], pre_pos_buf, pre_tid32_buf,
compressors.get(li), indexers.get(li),
moe_runners.get(li), se_runners.get(li), routers.get(li),
prod_lin=prod_lins.get(li))
prod_lin=prod_lins.get(li),
_use_fused_rmsnorm_quantize=not _args.no_fused_rmsnorm,
_ab_compare=_args.ab_compare and li < 3)
except Exception as e:
torch.cuda.synchronize()
err = torch.cuda.current_stream(gpu).query()
@@ -1390,7 +1454,9 @@ def main():
moe_runners.get(li), se_runners.get(li), routers.get(li),
prod_lin=prod_lins.get(li),
_profile_detail=(profile and step == 1),
_profile_times=cuda_layer_events if (profile and step == 1) else None)
_profile_times=cuda_layer_events if (profile and step == 1) else None,
_use_fused_rmsnorm_quantize=not _args.no_fused_rmsnorm,
_ab_compare=_args.ab_compare and li < 3)
X = X.to('cuda:0'); torch.cuda.set_device(0)
t_layers = time.perf_counter()