nvfp4-megamoe-kernel/diag_fold.py
biondizzle fd59222fc0 fix: stop folding global scale into float8 block scales
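Folding block_sf (float8) * global_sf (float32) back into float8 loses ~25% precision.
The product of block_sf (~56-448) and global_sf (~4.65e-05) is only ~0.0026-0.021,
which lands in float8's low-precision zone around and below its smallest normal value
(2^-6 ≈ 0.0156), where the step size is ~25%. The model output is garbage even though
every value stays finite.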

Fix: keep the block scales in their original float8 form and return the global
scales separately as float32 per-expert vectors. Apply each global scale as a
per-expert GEMM alpha in cutlass_grouped_nvfp4_gemm (which already iterates per
expert). For L1, which has separate gate/up global scales, use gate_gs as alpha
and apply the up_correction ratio to the up half after the GEMM.
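Each expert's global scale is a single float32 scalar, so it factors out of that expert's GEMM. The sketch below checks that identity in plain float32. It does not call cutlass_grouped_nvfp4_gemm; torch.einsum stands in for the grouped GEMM. The shapes, the random data, the helper names `grouped_gemm`/`per_expert`, and the definition up_correction = up_gs / gate_gs are all assumptions made for illustration.

```python
import torch

torch.manual_seed(0)
n_experts, tokens, k, n = 3, 4, 32, 8

# Hypothetical per-expert operands: weights and activations already multiplied
# by their float8 block scales, with the float32 global scales kept separate.
x = torch.randn(n_experts, tokens, k)
w = torch.randn(n_experts, n, k)
activation_gs = torch.rand(n_experts) + 0.5
weight_gs = torch.rand(n_experts) * 1e-4 + 1e-5

def grouped_gemm(a, b):
    """Stand-in for a grouped GEMM: one (tokens x k) @ (k x n) product per expert."""
    return torch.einsum("etk,enk->etn", a, b)

def per_expert(v):
    """Reshape a per-expert vector so it broadcasts over (expert, token, column)."""
    return v[:, None, None]

# Reference: global scales folded into the operands in float32 (no float8 rounding).
ref = grouped_gemm(x * per_expert(activation_gs), w * per_expert(weight_gs))

# Separate scales: GEMM on block-scaled operands, then one alpha per expert.
alpha = activation_gs * weight_gs
out = grouped_gemm(x, w) * per_expert(alpha)

# Fused gate/up GEMM (L1) with a different global scale per half. Assumption:
# up_correction = up_gs / gate_gs, applied to the up half after the GEMM.
gate_gs = torch.rand(n_experts) * 1e-4 + 1e-5
up_gs = torch.rand(n_experts) * 1e-4 + 1e-5
half = n // 2
ref_l1 = torch.cat([
    grouped_gemm(x, w[:, :half]) * per_expert(activation_gs * gate_gs),
    grouped_gemm(x, w[:, half:]) * per_expert(activation_gs * up_gs),
], dim=-1)

fused = grouped_gemm(x, w) * per_expert(activation_gs * gate_gs)
up_correction = up_gs / gate_gs
fused[..., half:] *= per_expert(up_correction)

print(torch.allclose(out, ref), torch.allclose(fused, ref_l1))
```

The alpha is applied in float32 after accumulation, so the tiny global scale never has to pass through float8.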

weight_transform.py: no more _fold_global_scale, returns (w, sf, global_sf)
nvfp4_mega_moe.py: per-expert alpha = activation_gs * weight_gs
kernel.py: per_expert_alpha parameter in grouped GEMM
deepseek_v4.py: updated type hints and comments
2026-05-15 12:42:53 +00:00


"""
Diagnostic: Check global scale folding precision for NVFP4 weights.
The fold is: sf_f32 * gs → clamp(0, 448) → float8_e4m3fn
Question: how much precision is lost in the float8 round-trip?
"""
import torch
# Simulate typical NVFP4 scale distributions
# block_scale (float8_e4m3fn) range: roughly 0.06 to 448
# global_scale (float32) range: varies per expert
# Test 1: If global_scale >> 1, product can exceed 448 → clamp → loss
# Test 2: If global_scale << 1, product can go subnormal → loss
# Test 3: Quantization error from 3-bit mantissa
# Simulate a range of scale values
block_scales = torch.tensor([0.0625, 0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 448.0], dtype=torch.float32)
global_scales = torch.tensor([0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0], dtype=torch.float32)
print("=== Float8 Folding Precision Analysis ===\n")
print(f"block_scales: {block_scales.tolist()}")
print(f"global_scales: {global_scales.tolist()}\n")
F8 = torch.finfo(torch.float8_e4m3fn)  # max = 448.0, tiny (smallest normal) = 2**-6
total_clamped = 0
total_subnormal = 0
max_rel_error = 0.0
for gs in global_scales.tolist():
    products = block_scales * gs
    clamped = products.clamp(0.0, F8.max)
    folded_f8 = clamped.to(torch.float8_e4m3fn)
    roundtrip = folded_f8.to(torch.float32)
    n_clamped = (products > F8.max).sum().item()
    # e4m3fn subnormals are the nonzero values below 2**-6 (0.0625 = 2**-4 overcounts)
    n_subnormal = ((roundtrip > 0) & (roundtrip < F8.tiny)).sum().item()
    # Error of the float8 rounding alone (clamping is counted above). Test the input,
    # not the output, so a product flushed to zero counts as 100% error.
    rel_errors = torch.where(clamped > 0, (roundtrip - clamped).abs() / clamped.clamp(min=1e-10), torch.zeros_like(clamped))
    max_err = rel_errors.max().item()
    total_clamped += n_clamped
    total_subnormal += n_subnormal
    max_rel_error = max(max_rel_error, max_err)
    if n_clamped > 0 or max_err > 0.05:
        print(f"gs={gs:.3f}: {n_clamped} clamped, max_rel_err={max_err:.4f}")
        for b, p, c, r in zip(block_scales.tolist(), products.tolist(), clamped.tolist(), roundtrip.tolist()):
            err = abs(r - c) / max(abs(c), 1e-10)
            if err > 0.01:
                print(f"  block={b:.4f} product={p:.4f} clamped={c:.4f} roundtrip={r:.4f} err={err:.4f}")
print(f"\nTotal clamped: {total_clamped}, Total subnormal: {total_subnormal}, Max relative error: {max_rel_error:.4f}")
# The real check: what's the float8_e4m3fn step size at various magnitudes?
print("\n=== Float8 E4M3 Step Sizes ===")
test_vals = [0.01, 0.1, 1.0, 10.0, 100.0, 448.0]
for v in test_vals:
    f8 = torch.tensor(v, dtype=torch.float32).to(torch.float8_e4m3fn)
    back = f8.to(torch.float32).item()
    # Find the neighbouring representable value by stepping the raw bit pattern.
    # 448.0 (0x7E) is the e4m3fn maximum and 0x7F is NaN, so step down there instead.
    u8 = f8.view(torch.uint8)
    neighbor_u8 = u8 - 1 if back == 448.0 else u8 + 1
    neighbor_val = neighbor_u8.view(torch.float8_e4m3fn).to(torch.float32).item()
    step = abs(neighbor_val - back)
    rel_step = step / back if back > 0 else 0.0
    print(f"  value={v:.3f} → f8={back:.6f} → neighbor={neighbor_val:.6f} step={step:.6f} rel={rel_step:.4f}")