Fix NVFP4 precision loss from folding global scales into float8 block scales.

Folding block_sf (float8) * global_sf (float32) back into float8 loses ~25% precision: the product of a ~56-448 block_sf and a ~4.65e-05 global_sf (~0.0026-0.021) lands in float8's low-precision, largely subnormal range, where the step between representable values is ~25% of the value. This makes the model output garbage even though all values are finite.

Fix: keep the block scales as the original float8 and return the global scales separately as per-expert float32 vectors. Apply the global scale as a per-expert GEMM alpha in cutlass_grouped_nvfp4_gemm (which already iterates per expert). For L1, which has separate gate/up global scales, use gate_gs as alpha and apply the up_correction ratio to the up half after the GEMM.

- weight_transform.py: no more _fold_global_scale; returns (w, sf, global_sf)
- nvfp4_mega_moe.py: per-expert alpha = activation_gs * weight_gs
- kernel.py: per_expert_alpha parameter in grouped GEMM
- deepseek_v4.py: updated type hints and comments
67 lines · 2.8 KiB · Python
"""
|
|
Diagnostic: Check global scale folding precision for NVFP4 weights.
|
|
|
|
The fold is: sf_f32 * gs → clamp(0, 448) → float8_e4m3fn
|
|
Question: how much precision is lost in the float8 round-trip?
|
|
"""
|
|
import torch
|
|
|
|
# Simulate typical NVFP4 scale distributions
|
|
# block_scale (float8_e4m3fn) range: roughly 0.06 to 448
|
|
# global_scale (float32) range: varies per expert
|
|
|
|
# Test 1: If global_scale >> 1, product can exceed 448 → clamp → loss
|
|
# Test 2: If global_scale << 1, product can go subnormal → loss
|
|
# Test 3: Quantization error from 3-bit mantissa
|
|
|
|
# Simulate a range of scale values
|
|
block_scales = torch.tensor([0.0625, 0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 448.0], dtype=torch.float32)
|
|
global_scales = torch.tensor([0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0], dtype=torch.float32)
|
|
|
|
print("=== Float8 Folding Precision Analysis ===\n")
|
|
print(f"block_scales: {block_scales.tolist()}")
|
|
print(f"global_scales: {global_scales.tolist()}\n")
|
|
|
|
# Sweep every global scale over the block-scale grid, fold the product back
# into float8_e4m3fn, and measure what the round-trip does.  Counters
# accumulate across all global scales and are summarised at the end.

# float8_e4m3fn limits: largest finite value, and smallest positive *normal*
# value (2**-6).  Nonzero values below F8_MIN_NORMAL are subnormal and have a
# fixed absolute step of 2**-9, so their relative precision degrades fast.
F8_MAX = 448.0
F8_MIN_NORMAL = 2.0 ** -6

total_clamped = 0
total_subnormal = 0
total_underflow = 0
max_rel_error = 0.0

for gs in global_scales:
    products = block_scales * gs
    clamped = products.clamp(0.0, F8_MAX)
    folded_f8 = clamped.to(torch.float8_e4m3fn)
    roundtrip = folded_f8.to(torch.float32)

    # Loss from clamping is counted here; it is NOT part of rel_errors below,
    # which measures only the float8 cast of the already-clamped value.
    n_clamped = (products > F8_MAX).sum().item()
    # Nonzero results below the smallest normal float8 value.
    n_subnormal = ((roundtrip > 0) & (roundtrip < F8_MIN_NORMAL)).sum().item()
    # Positive inputs that rounded all the way to zero (total loss).
    n_underflow = ((clamped > 0) & (roundtrip == 0)).sum().item()

    # Mask on the *input*, not the output: an element that underflows to 0
    # has 100% relative error and must not be reported as 0.
    rel_errors = torch.where(
        clamped > 0,
        (roundtrip - clamped).abs() / clamped.clamp(min=1e-10),
        torch.zeros_like(clamped),
    )
    max_err = rel_errors.max().item()

    total_clamped += n_clamped
    total_subnormal += n_subnormal
    total_underflow += n_underflow
    max_rel_error = max(max_rel_error, max_err)

    if n_clamped > 0 or max_err > 0.05:
        # .4g keeps tiny scales/products readable (.3f/.4f printed them as 0).
        print(f"gs={gs:.4g}: {n_clamped} clamped, max_rel_err={max_err:.4f}")
        for block, p, c, r, err in zip(block_scales, products, clamped, roundtrip, rel_errors):
            if err > 0.01:
                print(
                    f" block={block:.4f} product={p:.4g} clamped={c:.4g} "
                    f"roundtrip={r:.4g} err={err:.4f}"
                )

print(
    f"\nTotal clamped: {total_clamped}, Total subnormal: {total_subnormal}, "
    f"Total underflow to zero: {total_underflow}, "
    f"Max relative error: {max_rel_error:.4f}"
)
# The real check: what's the float8_e4m3fn step size at various magnitudes?
|
|
print("\n=== Float8 E4M3 Step Sizes ===")
|
|
test_vals = [0.01, 0.1, 1.0, 10.0, 100.0, 448.0]
|
|
for v in test_vals:
|
|
f8 = torch.tensor(v, dtype=torch.float32).to(torch.float8_e4m3fn)
|
|
back = f8.to(torch.float32)
|
|
# Find next representable value
|
|
u8 = f8.view(torch.uint8)
|
|
next_u8 = u8 + 1
|
|
next_f8 = next_u8.view(torch.float8_e4m3fn)
|
|
next_val = next_f8.to(torch.float32)
|
|
step = next_val - back
|
|
rel_step = step / back if back > 0 else 0
|
|
print(f" value={v:.3f} → f8={back:.6f} → next={next_val:.6f} step={step:.6f} rel={rel_step:.4f}")
|