Add A/B comparison mode for P4 fused vs unfused RMSNorm+quantize
- Added --ab-compare flag to run both fused and unfused paths for the first 3 layers
- Compares x_normed, gsa values, FP4 data, and GEMM outputs (q_a, kv)
- Added --no-fused-rmsnorm to disable P4 and use the unfused path
- This will help diagnose the correctness regression introduced by P4
35
NEXT_SESSION.md
Normal file
@@ -0,0 +1,35 @@
Here's the summary for next session:
## Summary: What Was Done & What to Investigate

### What was done today
1. **P4: Fused RMSNorm + NVFP4 quantize.** New CUDA kernel (2 launches instead of 6+). Integrated into `single_shot_inference.py` for the attention path. Unit test passes (cos=0.996-0.999 vs. unfused). Decode time improved from 0.53 to 0.43 s/token.

2. **P5: Fused mHC + RMSNorm + NVFP4 quantize.** New CUDA kernel (2 launches instead of 7+). Unit test passes on B200 (cos=0.996-0.999). NOT yet integrated into `single_shot_inference.py`; that needs a `moe_forward` refactoring.

3. **CRITICAL FIX: Indexer deadlock.** `__syncthreads()` inside a strided loop caused a deadlock at production context lengths (it always deadlocked when `num_valid % 128 != 0`). Fixed with a per-thread local top-k followed by a block-level merge. Both copies updated.
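The deadlock in item 3 comes down to threads of one block running a different number of loop iterations, so some threads reach `__syncthreads()` one more time than the others. A minimal sketch in plain Python, assuming a block of 128 threads and a loop of the form `for (i = tid; i < num_valid; i += 128)`. The block size is inferred from the `% 128` condition above, and the function names are hypothetical:

```python
BLOCK_THREADS = 128  # assumed block size, inferred from the `% 128` condition


def barrier_visits(tid: int, num_valid: int, block: int = BLOCK_THREADS) -> int:
    """Iterations (and therefore barrier calls) that thread `tid` executes in
    a loop of the form: for (i = tid; i < num_valid; i += block)."""
    return len(range(tid, num_valid, block))


def barrier_is_uniform(num_valid: int, block: int = BLOCK_THREADS) -> bool:
    """True when every thread in the block reaches the barrier equally often."""
    visits = {barrier_visits(tid, num_valid, block) for tid in range(block)}
    return len(visits) == 1


if __name__ == "__main__":
    for n in (256, 300):
        print(n, "uniform" if barrier_is_uniform(n) else "divergent")
```

A barrier inside such a loop is only safe when the trip count is uniform across the block. The fix described above sidesteps the problem by keeping per-thread state inside the loop and synchronizing only in the block-level merge.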
### The regression
After P4 integration, model output went from coherent English to garbled nonsense. This is a NEW regression: output was correct last night after P0-P3 + KV work.

### What to investigate (ordered by likelihood)

1. **`run_from_quantized` gsa shape bug (MOST LIKELY)**: The fused kernel produces a per-row gsa (shape `(M,)`). `run_from_quantized` passes this to the CuTeDSL NVFP4 GEMM as `global_scale_a`, but the GEMM expects a **single scalar**: one scale for the entire A matrix. For M=1 decode, shape `(1,)` may work as a scalar, BUT the gsa VALUE can still differ from the unfused path's scalar gsa because:
   - Unfused: quantize computes gsa from the full (M, N) tensor, giving a single scalar
   - Fused: computes gsa per row, then `[:1].reshape(1)` takes only the first row's gsa
   - These differ when rows have different magnitudes

2. **Dequant→requant noise for compressor**: The compressor gets `x_normed` by dequantizing the fused kernel's FP4 output. This introduces ~0.5% quantization error (cos=0.994) that wasn't present before. This noise propagates into compression scores and indexer queries.

3. **A/B test first**: Run with `_use_fused_rmsnorm_quantize=False` to confirm P4 is the cause. If output is correct with the unfused path, the bug is in the P4 path.
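The per-row versus whole-tensor difference in item 1 can be checked on paper. The sketch below is plain Python and assumes the convention `gsa = amax / (448 * 6)` (E4M3 maximum times E2M1 maximum). The real kernels may store the reciprocal or use a different constant, and the helper names are hypothetical:

```python
FP4_MAX = 6.0    # largest E2M1 magnitude
FP8_MAX = 448.0  # largest E4M3 magnitude


def global_scale(values):
    """Assumed convention: gsa = amax / (FP8_MAX * FP4_MAX)."""
    amax = max(abs(v) for v in values)
    return amax / (FP8_MAX * FP4_MAX)


def per_row_gsa(matrix):
    """One global scale per row, as the fused kernel is described to emit."""
    return [global_scale(row) for row in matrix]


def tensor_gsa(matrix):
    """One global scale for the whole (M, N) tensor, as in the unfused path."""
    return global_scale([v for row in matrix for v in row])


if __name__ == "__main__":
    x = [[0.5, -1.0, 0.25], [8.0, -2.0, 4.0]]
    rows = per_row_gsa(x)
    print("per-row gsa:      ", rows)
    print("first row only:   ", rows[0])
    print("whole-tensor gsa: ", tensor_gsa(x))
```

Under this convention the whole-tensor scale equals the maximum of the per-row scales, and the two coincide when M=1. If the real kernels behave the same way, a gsa mismatch would be expected to show up for M>1 (prefill) rather than for single-row decode, which is something the A/B `ratio` printout can confirm or rule out.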
### Key code paths
- **`dsv4/layers/linear.py`**: `run_from_quantized()`, where the gsa handling is suspicious
- **`dsv4/ops/gemm_runner.py`**: `run_nvfp4_grouped_gemm()`, which shows how `global_scale_a` is consumed
- **`dsv4/kernels/gemm/grouped.py`**: CuTeDSL GEMM, where `global_scale_a` is a scalar
- **`single_shot_inference.py:855-870`**: P4 integration point
- **`dsv4/kernels/cuda/fused_rmsnorm_quantize.cu`**: the fused kernel

### Quick fix ideas
- Use a **scalar gsa** (reduce the per-row gsa to a single max) for the GEMM. This keeps the fusion but makes gsa compatible.
- Or: don't use `run_from_quantized`. Use the fused kernel only for the BF16 output (dequant) and let each linear re-quantize with its own scalar gsa (saves rmsnorm launches but not quantize launches).
- Or: fix the CuTeDSL GEMM to support per-row `global_scale_a`. (THIS IS COMPLEX, BUT IF IT IS THE CORRECT WAY OF DOING THINGS, I WANT IT DONE THAT WAY!)
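One caveat for the first quick fix: if the block scales were encoded relative to a per-row gsa, swapping in a scalar gsa without touching them changes the dequantized values. A minimal sketch, assuming dequantization has the form `x ≈ q * sf * gsa`. The function name and the list-of-lists layout are hypothetical, and a real kernel would have to re-round the rescaled block scales to E4M3:

```python
def to_scalar_gsa(per_row_gsa, block_scales):
    """Collapse per-row global scales to their max and re-express block scales.

    Assumes dequantization is x ~= q * sf * gsa, so each block scale is
    multiplied by gsa_row / gsa_scalar to keep the product sf * gsa unchanged.
    """
    gsa_scalar = max(per_row_gsa)
    rescaled = [
        [sf * (gsa_row / gsa_scalar) for sf in row_sfs]
        for gsa_row, row_sfs in zip(per_row_gsa, block_scales)
    ]
    return gsa_scalar, rescaled


if __name__ == "__main__":
    gsa_rows = [0.5, 2.0]
    sfs = [[4.0, 8.0], [1.0, 3.0]]
    gsa, new_sfs = to_scalar_gsa(gsa_rows, sfs)
    print(gsa, new_sfs)
```

Because rows with a small gsa get their block scales shrunk, this rescaling can push them toward E4M3 underflow. Computing the scalar amax before quantizing, rather than patching scales afterwards, avoids that, at the cost of an extra reduction inside the fused kernel.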
@@ -120,7 +120,7 @@ FP32 cos/sin cache, forward + inverse, in-place operation.
Total BF16 at 1M context: ~10 GB on 8×B200. Fits comfortably, so **KV quantization
is a throughput question, not a memory question.**

-## Why FP4 storage is the right answer for the compressed streams
+## Why FP4 storage is the right answer for the compressed streams - THIS IS NOT WHAT WE ENDED UP USING BECAUSE THE COSINE WAS TOO FAR OFF

Three reasons, in priority order:

@@ -168,6 +168,37 @@ when E7 lands).
- **Recall@k for indexer ≥ 99% vs FP32 oracle** (the bar from the prior
  indexer-fix audit). Critical: FP4 must not corrupt top-k ranking.

### THE ABOVE DID NOT WORK... WHY NOT NVFP4 (native Blackwell FP4)?
─────────────────────────────────────
We *really* wanted to use NVFP4 (E2M1 + E4M3 block scales + FP32 global scale)
for compressed KV storage. Blackwell's native FP4→MMA path would have given us
3.5× memory savings and direct tensor-core consumption: the dream pipeline.
We tried. Hard. Three separate approaches:

1. Fused compressor_reduce_quant.cu: single-kernel compress→NVFP4. Bugs in
   cross-warp block amax reduction and shared memory corruption (s_scratch
   stomping adjacent variables). Best cos=0.703. Dead.
2. Proven two-kernel path (amax_gsa → quantize_from_buffer) using kv_quantize.cu's
   compute_amax_gsa_fp32 + quantize_nvfp4_from_fp32. cos=0.995 on random data,
   but that's the *quantize/dequant* round-trip in isolation. In the full pipeline,
   the 4-bit precision on 448 non-RoPE dimensions accumulated error across 61 layers
   of mHC: residual |X| already grows to 300-500, and NVFP4's 16-element block
   quantization (4.5 bits effective) added ~0.5% per layer on top of that.
3. FP32 RoPE kernel (rope_fp32 in kv_quantize.cu) to avoid the BF16 RoPE intermediate.
   Had an indexing bug (cos=0.977 for M>1). Fixed, but the real issue was NVFP4,
   not RoPE.

The verdict: NVFP4's 4.5 effective bits per element is simply too coarse for
compressed KV values that get summed in attention softmax. FP8_E4M3's 5.3 effective
bits give cos=0.9997 round-trip (vs NVFP4's 0.995), and that ~0.5% difference compounds
fatally across 61 layers.

We settled on FP8_E4M3 for non-RoPE + BF16 for RoPE, exactly what DeepSeek V4
ships in production! Not because we couldn't build the NVFP4 path (we did, it compiled
and ran), but because the math didn't hold up. Sometimes 4 bits isn't enough.
If Blackwell adds a finer-grained FP4 variant (8-element blocks, 6 effective bits),
revisit this. The kernels exist. The quantize/dequant path is proven. The precision
just isn't there yet for attention-sensitive KV values.

---

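The NVFP4 verdict in Part 2 rests on two pieces of arithmetic: the 4.5 bits per element (4-bit values plus one 8-bit scale per 16-element block) and the way a small per-layer loss compounds over 61 layers. The sketch below reproduces both in plain Python. The multiplicative model is a deliberately crude assumption (real errors do not compound exactly like this), and the function names are hypothetical:

```python
LAYERS = 61  # layer count quoted in the notes


def storage_bits_per_element(value_bits: int, scale_bits: int, block: int) -> float:
    """Bits per element including the amortized per-block scale."""
    return value_bits + scale_bits / block


def compounded_cosine(per_layer_cos: float, layers: int = LAYERS) -> float:
    """Crude model: every layer independently retains `per_layer_cos` of the
    alignment, so the end-to-end figure is the product over all layers."""
    return per_layer_cos ** layers


if __name__ == "__main__":
    print(f"NVFP4 bits/element: {storage_bits_per_element(4, 8, 16)}")
    print(f"NVFP4 0.995^61    = {compounded_cosine(0.995):.3f}")
    print(f"FP8   0.9997^61   = {compounded_cosine(0.9997):.3f}")
```

Under this model a 0.995 round-trip ends up at roughly 0.74 after 61 layers, while 0.9997 stays at roughly 0.98, which is the gap the verdict describes.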
# PART 3 — OTHER FUSION WINS, RANKED BY EFFORT/IMPACT
@@ -27,6 +27,8 @@ def parse_args():
     p.add_argument('--seed', type=int, default=42)
     p.add_argument('--verbose', type=int, default=1)
     p.add_argument('--prefill-only', action='store_true')
+    p.add_argument('--ab-compare', action='store_true', help='A/B compare fused vs unfused P4 for first 3 layers')
+    p.add_argument('--no-fused-rmsnorm', action='store_true', help='Disable P4 fused RMSNorm+quantize (use unfused path)')
     p.add_argument('--warmup-gsa', action='store_true', help='Fix gsa values after first decode step (eliminates amax kernel launches)')
     p.add_argument('--profile', action='store_true', help='Profile per-component GPU time using CUDA events')
     p.add_argument('--num-gpus', type=int, default=8)
@@ -196,7 +198,8 @@ class CUDAGraphDecoder:
                 kv_caches[li], positions, token_id,
                 compressors.get(li), indexers.get(li),
                 moe_runners.get(li), se_runners.get(li), routers.get(li),
-                prod_lin=prod_lins.get(li)
+                prod_lin=prod_lins.get(li),
+                _use_fused_rmsnorm_quantize=True
             )
             # Copy output to fixed buffer
             self.x_out_bufs[li].copy_(X_out)
@@ -852,11 +855,70 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
                   compressor=None, indexer=None,
                   moe_runner=None, se_runner=None, router=None,
                   prod_lin=None, _profile_detail=False, _profile_times=None,
-                  _use_fused_rmsnorm_quantize=True):
+                  _use_fused_rmsnorm_quantize=True,
+                  _ab_compare=False):
+    """Forward one transformer layer.
+
+    _ab_compare: if True, run BOTH fused and unfused paths for this layer
+    and print detailed numerical comparison. Only use for first few layers.
+    """
     # P4: Fused RMSNorm + NVFP4 quantize (eliminates ~488 launches/token)
     from dsv4.ops.quantize import rmsnorm_quantize_nvfp4, QuantizedActivation, dequantize_nvfp4
     x_in, ctx_a = attn_mhc.pre_block(X_l)
-    if _use_fused_rmsnorm_quantize:
+
+    # A/B comparison mode: run BOTH paths, compare intermediate results
+    if _ab_compare and _use_fused_rmsnorm_quantize:
+        # --- FUSED PATH ---
+        x_quant_fused = rmsnorm_quantize_nvfp4(x_in, attn_norm_w.to(x_in.device, torch.float32))
+        x_normed_fused = dequantize_nvfp4(x_quant_fused.x_fp4, x_quant_fused.x_sf, x_quant_fused.gsa)
+        # --- UNFUSED PATH ---
+        x_normed_unfused = rmsnorm(x_in, attn_norm_w)
+        # Quantize unfused x_normed the normal way (as run() would)
+        from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
+        x_fp4_unf, x_sf_unf, gsa_unf = quantize_nvfp4_gpu_fused(x_normed_unfused)
+
+        # Compare x_normed
+        cos_normed = torch.nn.functional.cosine_similarity(
+            x_normed_fused.flatten().float(), x_normed_unfused.flatten().float(), dim=0).item()
+        print(f" L{li} A/B: x_normed |fused|={x_normed_fused.abs().max().item():.6f} "
+              f"|unfused|={x_normed_unfused.abs().max().item():.6f} cos={cos_normed:.6f}", flush=True)
+
+        # Compare gsa values
+        gsa_fused_val = x_quant_fused.gsa[0].item()
+        gsa_unfused_val = gsa_unf[0].item()
+        print(f" L{li} A/B: gsa fused={gsa_fused_val:.8f} unfused={gsa_unfused_val:.8f} "
+              f"ratio={gsa_fused_val/max(gsa_unfused_val,1e-12):.8f}", flush=True)
+
+        # Compare FP4 data (should be different due to different intermediate precision)
+        fp4_match = torch.equal(x_quant_fused.x_fp4.view(torch.uint8), x_fp4_unf.view(torch.uint8))
+        sf_match = torch.equal(x_quant_fused.x_sf.view(torch.uint8), x_sf_unf.view(torch.uint8))
+        print(f" L{li} A/B: fp4_identical={fp4_match} sf_identical={sf_match}", flush=True)
+
+        # Compare block scales
+        sf_diff = (x_quant_fused.x_sf.view(torch.uint8).float() - x_sf_unf.view(torch.uint8).float()).abs()
+        print(f" L{li} A/B: sf_diff max={sf_diff.max().item():.0f} mean={sf_diff.mean().item():.2f}", flush=True)
+
+        # Run BOTH GEMM paths and compare q_a output
+        q_a_fused = prod_lin['q_a'].run_from_quantized(x_quant_fused)
+        q_a_unfused = prod_lin['q_a'](x_normed_unfused)
+        cos_qa = torch.nn.functional.cosine_similarity(
+            q_a_fused.flatten().float(), q_a_unfused.flatten().float(), dim=0).item()
+        print(f" L{li} A/B: q_a |fused|={q_a_fused.abs().max().item():.6f} "
+              f"|unfused|={q_a_unfused.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True)
+
+        # Run BOTH GEMM paths for kv
+        kv_fused = prod_lin['kv'].run_from_quantized(x_quant_fused)
+        kv_unfused = prod_lin['kv'](x_normed_unfused)
+        cos_kv = torch.nn.functional.cosine_similarity(
+            kv_fused.flatten().float(), kv_unfused.flatten().float(), dim=0).item()
+        print(f" L{li} A/B: kv |fused|={kv_fused.abs().max().item():.6f} "
+              f"|unfused|={kv_unfused.abs().max().item():.6f} cos={cos_kv:.6f}", flush=True)
+
+        # Now continue with the UNFUSED path (which we know works)
+        # to see if the rest of the layer also diverges
+        x_normed = x_normed_unfused
+        x_quant_attn = None
+    elif _use_fused_rmsnorm_quantize:
         x_quant_attn = rmsnorm_quantize_nvfp4(x_in, attn_norm_w.to(x_in.device, torch.float32))
         # Dequantize for compressor/indexer (1 kernel launch)
         x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
@@ -1311,7 +1373,9 @@ def main():
                     kv_caches[li], pre_pos_buf, pre_tid32_buf,
                     compressors.get(li), indexers.get(li),
                     moe_runners.get(li), se_runners.get(li), routers.get(li),
-                    prod_lin=prod_lins.get(li))
+                    prod_lin=prod_lins.get(li),
+                    _use_fused_rmsnorm_quantize=not _args.no_fused_rmsnorm,
+                    _ab_compare=_args.ab_compare and li < 3)
             except Exception as e:
                 torch.cuda.synchronize()
                 err = torch.cuda.current_stream(gpu).query()
@@ -1390,7 +1454,9 @@ def main():
                 moe_runners.get(li), se_runners.get(li), routers.get(li),
                 prod_lin=prod_lins.get(li),
                 _profile_detail=(profile and step == 1),
-                _profile_times=cuda_layer_events if (profile and step == 1) else None)
+                _profile_times=cuda_layer_events if (profile and step == 1) else None,
+                _use_fused_rmsnorm_quantize=not _args.no_fused_rmsnorm,
+                _ab_compare=_args.ab_compare and li < 3)
         X = X.to('cuda:0'); torch.cuda.set_device(0)
         t_layers = time.perf_counter()
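The way the two new flags reach `forward_layer` in the hunks above can be summarized in a small stand-alone sketch. Only the flag names and the two keyword arguments come from the diff. The `layer_kwargs` helper and its `ab_layers` parameter are hypothetical and exist only to make the mapping testable:

```python
import argparse


def parse_args(argv=None):
    p = argparse.ArgumentParser()
    p.add_argument('--ab-compare', action='store_true',
                   help='A/B compare fused vs unfused P4 for first 3 layers')
    p.add_argument('--no-fused-rmsnorm', action='store_true',
                   help='Disable P4 fused RMSNorm+quantize (use unfused path)')
    return p.parse_args(argv)


def layer_kwargs(args, li, ab_layers=3):
    """Per-layer keyword arguments, mirroring the call sites in main()."""
    return {
        '_use_fused_rmsnorm_quantize': not args.no_fused_rmsnorm,
        '_ab_compare': args.ab_compare and li < ab_layers,
    }


if __name__ == "__main__":
    args = parse_args(['--ab-compare'])
    for li in range(5):
        print(li, layer_kwargs(args, li))
```

Two consequences follow from the diff as shown. Passing `--ab-compare` together with `--no-fused-rmsnorm` skips the comparison, because `forward_layer` only enters the A/B branch when `_use_fused_rmsnorm_quantize` is also true. And the `CUDAGraphDecoder` call site still hard-codes `_use_fused_rmsnorm_quantize=True`, so `--no-fused-rmsnorm` does not appear to reach that path.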