From eb5ef93bf1390d4129d5eb91b19ccde29d49ec53 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 2 Jun 2026 18:49:30 +0000 Subject: [PATCH] Add A/B comparison mode for P4 fused vs unfused RMSNorm+quantize - Added --ab-compare flag to run both fused and unfused paths for first 3 layers - Compares x_normed, gsa values, FP4 data, and GEMM outputs (q_a, kv) - Added --no-fused-rmsnorm to disable P4 and use unfused path - This will help diagnose the correctness regression introduced by P4 --- NEXT_SESSION.md | 35 ++++++++++++++++++ PERFORMANCE_AUDIT.md | 33 ++++++++++++++++- single_shot_inference.py | 76 +++++++++++++++++++++++++++++++++++++--- 3 files changed, 138 insertions(+), 6 deletions(-) create mode 100644 NEXT_SESSION.md diff --git a/NEXT_SESSION.md b/NEXT_SESSION.md new file mode 100644 index 00000000..998673c0 --- /dev/null +++ b/NEXT_SESSION.md @@ -0,0 +1,35 @@ +Here's the summary for next session: +## Summary: What Was Done & What to Investigate + +### What was done today +1. **P4 — Fused RMSNorm + NVFP4 quantize**: New CUDA kernel (2 launches vs 6+). Integrated into `single_shot_inference.py` for the attention path. Unit test passes (cos=0.996-0.999 vs unfused). Decode speed improved 0.53→0.43s/token. + +2. **P5 — Fused mHC + RMSNorm + NVFP4 quantize**: New CUDA kernel (2 launches vs 7+). Unit test passes on B200 (cos=0.996-0.999). NOT yet integrated into single_shot — needs moe_forward refactoring. + +3. **CRITICAL FIX — Indexer deadlock**: `__syncthreads()` inside a strided loop caused deadlock at production context lengths (always deadlocked when `num_valid % 128 != 0`). Fixed with per-thread local top-k + block-level merge. Both copies updated. + +### The regression +After P4 integration, model output went from coherent English to garbled nonsense. This is a NEW regression — output was correct last night after P0-P3 + KV work. + +### What to investigate (ordered by likelihood) + +1. **`run_from_quantized` gsa shape bug (MOST LIKELY)**: The fused kernel produces per-row gsa (shape `(M,)`). `run_from_quantized` passes this to the CuTeDSL NVFP4 GEMM as `global_scale_a`. But the GEMM expects a **single scalar** — it's one scale for the entire A matrix. For M=1 decode, shape `(1,)` may work as a scalar, BUT the gsa VALUE differs from the unfused path's scalar gsa because: + - Unfused: quantize computes gsa from the full (M, N) tensor — single scalar + - Fused: computes gsa per-row, then `[:1].reshape(1)` takes only the first row's gsa + - These differ when rows have different magnitudes + +2. **Dequant→requant noise for compressor**: The compressor gets `x_normed` by dequantizing the fused kernel's FP4 output. This introduces ~0.5% quantization error (cos=0.994) that wasn't present before. This noise propagates into compression scores and indexer queries. + +3. **A/B test first**: Run with `_use_fused_rmsnorm_quantize=False` to confirm P4 is the cause. If output is correct with unfused, the bug is in the P4 path. + +### Key code paths +- **`dsv4/layers/linear.py`**: `run_from_quantized()` — the gsa handling is suspicious +- **`dsv4/ops/gemm_runner.py`**: `run_nvfp4_grouped_gemm()` — how global_scale_a is consumed +- **`dsv4/kernels/gemm/grouped.py`**: CuTeDSL GEMM — global_scale_a is scalar +- **`single_shot_inference.py:855-870`**: P4 integration point +- **`dsv4/kernels/cuda/fused_rmsnorm_quantize.cu`**: the fused kernel + +### Quick fix ideas +- Use **scalar gsa** (reduce per-row gsa to a single max) for the GEMM — keeps the fusion but makes gsa compatible +- Or: don't use `run_from_quantized` — just use the fused kernel for the BF16 output (dequant) and let each linear re-quantize with its own scalar gsa (saves rmsnorm launches but not quantize launches) +- Or: fix the CuTeDSL GEMM to support per-row global_scale_a (THIS IS COMPLEX, BUT IF IT IS THE CORRECT WAY OF DOING THINGS. I WANT IT DONE THAT WAY!!!!!) \ No newline at end of file diff --git a/PERFORMANCE_AUDIT.md b/PERFORMANCE_AUDIT.md index 28e9b480..d3de8902 100644 --- a/PERFORMANCE_AUDIT.md +++ b/PERFORMANCE_AUDIT.md @@ -120,7 +120,7 @@ FP32 cos/sin cache, forward + inverse, in-place operation. Total BF16 at 1M context: ~10 GB on 8×B200. Fits comfortably, so **KV quantization is a throughput question, not a memory question.** -## Why FP4 storage is the right answer for the compressed streams +## Why FP4 storage is the right answer for the compressed streams - THIS IS NOT WHAT WE ENDED UP USING BECAUSE THE COSINE WAS TOO FAR OFF, Three reasons, in priority order: @@ -168,6 +168,37 @@ when E7 lands). - **Recall@k for indexer ≥ 99% vs FP32 oracle** (the bar from the prior indexer-fix audit). Critical — FP4 must not corrupt top-k ranking. +### THE ABOVE DID NOT WORK... WHY NOT NVFP4 (native Blackwell FP4)? + ───────────────────────────────────── + We *really* wanted to use NVFP4 (E2M1 + E4M3 block scales + FP32 global scale) + for compressed KV storage. Blackwell's native FP4→MMA path would have given us + 3.5× memory savings and direct tensor-core consumption — the dream pipeline. + We tried. Hard. Three separate approaches: + 1. Fused compressor_reduce_quant.cu — single-kernel compress→NVFP4. Bugs in + cross-warp block amax reduction and shared memory corruption (s_scratch + stomping adjacent variables). Best cos=0.703. Dead. + 2. Proven two-kernel path (amax_gsa → quantize_from_buffer) using kv_quantize.cu's + compute_amax_gsa_fp32 + quantize_nvfp4_from_fp32. cos=0.995 on random data, + but that's the *quantize/dequant* round-trip in isolation. In the full pipeline, + the 4-bit precision on 448 non-RoPE dimensions accumulated error across 61 layers + of mHC — residual |X| already grows to 300-500, and NVFP4's 16-element block + quantization (4.5 bits effective) added ~0.5% per layer on top of that. + 3. FP32 RoPE kernel (rope_fp32 in kv_quantize.cu) to avoid BF16 RoPE intermediate. + Had an indexing bug (cos=0.977 for M>1). Fixed but the real issue was NVFP4, + not RoPE. + The verdict: NVFP4's 4.5 effective bits per element is simply too coarse for + compressed KV values that get summed in attention softmax. FP8_E4M3's 5.3 effective + bits gives cos=0.9997 round-trip (vs NVFP4's 0.995) — that 0.4% difference compounds + fatally across 61 layers. + + +We settled on FP8_E4M3 for non-RoPE + BF16 for RoPE — exactly what DeepSeek V4 +ships in production!!!!!!!! Not because we couldn't build the NVFP4 path (we did, it compiled +and ran), but because the math didn't hold up. Sometimes 4 bits isn't enough. +If Blackwell adds a finer-grained FP4 variant (8-element blocks, 6 effective bits), +revisit this. The kernels exist. The quantize/dequant path is proven. The precision +just isn't there yet for attention-sensitive KV values. + --- # PART 3 — OTHER FUSION WINS, RANKED BY EFFORT/IMPACT diff --git a/single_shot_inference.py b/single_shot_inference.py index 68e9441d..787cd391 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -27,6 +27,8 @@ def parse_args(): p.add_argument('--seed', type=int, default=42) p.add_argument('--verbose', type=int, default=1) p.add_argument('--prefill-only', action='store_true') + p.add_argument('--ab-compare', action='store_true', help='A/B compare fused vs unfused P4 for first 3 layers') + p.add_argument('--no-fused-rmsnorm', action='store_true', help='Disable P4 fused RMSNorm+quantize (use unfused path)') p.add_argument('--warmup-gsa', action='store_true', help='Fix gsa values after first decode step (eliminates amax kernel launches)') p.add_argument('--profile', action='store_true', help='Profile per-component GPU time using CUDA events') p.add_argument('--num-gpus', type=int, default=8) @@ -196,7 +198,8 @@ class CUDAGraphDecoder: kv_caches[li], positions, token_id, compressors.get(li), indexers.get(li), moe_runners.get(li), se_runners.get(li), routers.get(li), - prod_lin=prod_lins.get(li) + prod_lin=prod_lins.get(li), + _use_fused_rmsnorm_quantize=True ) # Copy output to fixed buffer self.x_out_bufs[li].copy_(X_out) @@ -852,11 +855,70 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, compressor=None, indexer=None, moe_runner=None, se_runner=None, router=None, prod_lin=None, _profile_detail=False, _profile_times=None, - _use_fused_rmsnorm_quantize=True): + _use_fused_rmsnorm_quantize=True, + _ab_compare=False): + """Forward one transformer layer. + + _ab_compare: if True, run BOTH fused and unfused paths for this layer + and print detailed numerical comparison. Only use for first few layers. + """ # P4: Fused RMSNorm + NVFP4 quantize — eliminates ~488 launches/token from dsv4.ops.quantize import rmsnorm_quantize_nvfp4, QuantizedActivation, dequantize_nvfp4 x_in, ctx_a = attn_mhc.pre_block(X_l) - if _use_fused_rmsnorm_quantize: + + # A/B comparison mode: run BOTH paths, compare intermediate results + if _ab_compare and _use_fused_rmsnorm_quantize: + # --- FUSED PATH --- + x_quant_fused = rmsnorm_quantize_nvfp4(x_in, attn_norm_w.to(x_in.device, torch.float32)) + x_normed_fused = dequantize_nvfp4(x_quant_fused.x_fp4, x_quant_fused.x_sf, x_quant_fused.gsa) + # --- UNFUSED PATH --- + x_normed_unfused = rmsnorm(x_in, attn_norm_w) + # Quantize unfused x_normed the normal way (as run() would) + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + x_fp4_unf, x_sf_unf, gsa_unf = quantize_nvfp4_gpu_fused(x_normed_unfused) + + # Compare x_normed + cos_normed = torch.nn.functional.cosine_similarity( + x_normed_fused.flatten().float(), x_normed_unfused.flatten().float(), dim=0).item() + print(f" L{li} A/B: x_normed |fused|={x_normed_fused.abs().max().item():.6f} " + f"|unfused|={x_normed_unfused.abs().max().item():.6f} cos={cos_normed:.6f}", flush=True) + + # Compare gsa values + gsa_fused_val = x_quant_fused.gsa[0].item() + gsa_unfused_val = gsa_unf[0].item() + print(f" L{li} A/B: gsa fused={gsa_fused_val:.8f} unfused={gsa_unfused_val:.8f} " + f"ratio={gsa_fused_val/max(gsa_unfused_val,1e-12):.8f}", flush=True) + + # Compare FP4 data (should be different due to different intermediate precision) + fp4_match = torch.equal(x_quant_fused.x_fp4.view(torch.uint8), x_fp4_unf.view(torch.uint8)) + sf_match = torch.equal(x_quant_fused.x_sf.view(torch.uint8), x_sf_unf.view(torch.uint8)) + print(f" L{li} A/B: fp4_identical={fp4_match} sf_identical={sf_match}", flush=True) + + # Compare block scales + sf_diff = (x_quant_fused.x_sf.view(torch.uint8).float() - x_sf_unf.view(torch.uint8).float()).abs() + print(f" L{li} A/B: sf_diff max={sf_diff.max().item():.0f} mean={sf_diff.mean().item():.2f}", flush=True) + + # Run BOTH GEMM paths and compare q_a output + q_a_fused = prod_lin['q_a'].run_from_quantized(x_quant_fused) + q_a_unfused = prod_lin['q_a'](x_normed_unfused) + cos_qa = torch.nn.functional.cosine_similarity( + q_a_fused.flatten().float(), q_a_unfused.flatten().float(), dim=0).item() + print(f" L{li} A/B: q_a |fused|={q_a_fused.abs().max().item():.6f} " + f"|unfused|={q_a_unfused.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True) + + # Run BOTH GEMM paths for kv + kv_fused = prod_lin['kv'].run_from_quantized(x_quant_fused) + kv_unfused = prod_lin['kv'](x_normed_unfused) + cos_kv = torch.nn.functional.cosine_similarity( + kv_fused.flatten().float(), kv_unfused.flatten().float(), dim=0).item() + print(f" L{li} A/B: kv |fused|={kv_fused.abs().max().item():.6f} " + f"|unfused|={kv_unfused.abs().max().item():.6f} cos={cos_kv:.6f}", flush=True) + + # Now continue with the UNFUSED path (which we know works) + # to see if the rest of the layer also diverges + x_normed = x_normed_unfused + x_quant_attn = None + elif _use_fused_rmsnorm_quantize: x_quant_attn = rmsnorm_quantize_nvfp4(x_in, attn_norm_w.to(x_in.device, torch.float32)) # Dequantize for compressor/indexer (1 kernel launch) x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa) @@ -1311,7 +1373,9 @@ def main(): kv_caches[li], pre_pos_buf, pre_tid32_buf, compressors.get(li), indexers.get(li), moe_runners.get(li), se_runners.get(li), routers.get(li), - prod_lin=prod_lins.get(li)) + prod_lin=prod_lins.get(li), + _use_fused_rmsnorm_quantize=not _args.no_fused_rmsnorm, + _ab_compare=_args.ab_compare and li < 3) except Exception as e: torch.cuda.synchronize() err = torch.cuda.current_stream(gpu).query() @@ -1390,7 +1454,9 @@ def main(): moe_runners.get(li), se_runners.get(li), routers.get(li), prod_lin=prod_lins.get(li), _profile_detail=(profile and step == 1), - _profile_times=cuda_layer_events if (profile and step == 1) else None) + _profile_times=cuda_layer_events if (profile and step == 1) else None, + _use_fused_rmsnorm_quantize=not _args.no_fused_rmsnorm, + _ab_compare=_args.ab_compare and li < 3) X = X.to('cuda:0'); torch.cuda.set_device(0) t_layers = time.perf_counter()