nvfp4-megamoe-kernel/diag_issues.py
biondizzle fd59222fc0 fix: stop folding global scale into float8 block scales
Folding block_sf (float8) * global_sf (float32) back into float8 loses ~25% precision.
The product of ~56-448 block_sf and ~4.65e-05 global_sf (about 0.0026-0.021) lands
around or below float8_e4m3fn's smallest normal value (2^-6 = 0.015625), a low-precision
zone where the step size is ~25%. This makes model output garbage despite finite values.
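
The ~25% figure can be made concrete with the minimal pure-Python model of float8_e4m3fn round-to-nearest-even below. This is an illustrative sketch, not torch's cast code: it assumes the clamp(0, 448) saturation used by the old fold, and the helper names are hypothetical. It folds a few block scales from the ~56-448 range with global_sf = 4.65e-05 and reports the relative error of the float8 result.

```python
import math

E4M3_MAX = 448.0                # largest finite float8_e4m3fn value
E4M3_MIN_NORMAL = 2.0 ** -6     # smallest normal value
E4M3_SUBNORMAL_STEP = 2.0 ** -9  # spacing of subnormals (and of the lowest normal binade)
MANTISSA_BITS = 3


def quantize_e4m3(x: float) -> float:
    """Round a non-negative float to the nearest float8_e4m3fn value (hypothetical helper).

    Saturates at 448, mirroring the clamp(0, 448) applied before the cast.
    Python's round() breaks ties to even, like IEEE round-to-nearest-even.
    """
    x = min(max(x, 0.0), E4M3_MAX)
    if x < E4M3_MIN_NORMAL:
        step = E4M3_SUBNORMAL_STEP
    else:
        _, exp = math.frexp(x)  # x = m * 2**exp with 0.5 <= m < 1
        step = 2.0 ** (exp - 1 - MANTISSA_BITS)
    return round(x / step) * step


def fold_rel_error(block_sf: float, global_sf: float) -> float:
    """Relative error of storing block_sf * global_sf back in float8_e4m3fn."""
    exact = block_sf * global_sf
    return abs(quantize_e4m3(exact) - exact) / exact


for block_sf in (56.0, 112.0, 224.0, 448.0):
    print(f"block_sf={block_sf:>5.0f}  product={block_sf * 4.65e-05:.6f}  "
          f"rel_err={fold_rel_error(block_sf, 4.65e-05):.3f}")
```

At block_sf = 56, the product (~0.0026) rounds to 2^-9, a relative error of about 25%. The error shrinks toward 448 because the absolute step stays at 2^-9 across this whole range while the value grows.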

Fix: keep block scales in their original float8 form and return global scales
separately as per-expert float32 vectors. Apply the global scale as a per-expert
GEMM alpha in cutlass_grouped_nvfp4_gemm (which already iterates per expert). For
L1, which has separate gate/up global scales, use gate_gs as alpha and multiply the
up half by the up_correction ratio (up_gs / gate_gs) after the GEMM.
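
The global scales are scalars per expert (and per gate/up half for L1), so they commute with the GEMM. Applying them as an alpha afterwards is therefore exact in real arithmetic. The sketch below checks this on one expert with tiny pure-Python matrices. It rests on three assumptions:

- The function names are hypothetical.
- The operands stand in for block-dequantized NVFP4 data.
- `up_correction = up_gs / gate_gs` is this sketch's reading of the "up_correction ratio".

In the grouped GEMM, the same alpha computation would run once per expert with that expert's scales.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain (M, K) @ (K, N) product on nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def l1_with_alpha(x: Matrix, w: Matrix, act_gs: float, gate_gs: float, up_gs: float) -> Matrix:
    """Hypothetical L1 for one expert: GEMM on block-dequantized operands, global scales after.

    w is (K, 2*I) with gate columns first, then up columns.
    """
    alpha = act_gs * gate_gs         # per-expert GEMM alpha
    up_correction = up_gs / gate_gs  # ratio applied to the up half after the GEMM
    inter = len(w[0]) // 2
    out = [[alpha * v for v in row] for row in matmul(x, w)]
    for row in out:
        for j in range(inter, 2 * inter):
            row[j] *= up_correction
    return out


def l1_reference(x: Matrix, w: Matrix, act_gs: float, gate_gs: float, up_gs: float) -> Matrix:
    """Reference: apply each float32 global scale to its operand before the GEMM."""
    inter = len(w[0]) // 2
    x_full = [[v * act_gs for v in row] for row in x]
    w_full = [[v * (gate_gs if j < inter else up_gs) for j, v in enumerate(row)] for row in w]
    return matmul(x_full, w_full)


x = [[1.0, 2.0], [0.5, -1.0]]                        # (M=2, K=2)
w = [[1.0, -2.0, 0.5, 3.0], [4.0, 0.25, -1.0, 2.0]]  # (K=2, 2*I=4): gate cols 0-1, up cols 2-3
scales = (0.5, 4.65e-05, 6.2e-05)                    # act_gs, gate_gs, up_gs (illustrative)
got = l1_with_alpha(x, w, *scales)
ref = l1_reference(x, w, *scales)
print(all(math.isclose(g, r, rel_tol=1e-9) for gr, rr in zip(got, ref) for g, r in zip(gr, rr)))
```

Dividing by gate_gs makes the ratio undefined when gate_gs is zero. An unloaded, all-zero weight_scale_2 would then no longer fail silently as zeros.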

weight_transform.py: no more _fold_global_scale, returns (w, sf, global_sf)
nvfp4_mega_moe.py: per-expert alpha = activation_gs * weight_gs
kernel.py: per_expert_alpha parameter in grouped GEMM
deepseek_v4.py: updated type hints and comments
2026-05-15 12:42:53 +00:00

329 lines
14 KiB
Python

"""
Diagnostic script for NVFP4 mega_moe issues.
Run on the B200 server. Checks:
1. Global scale folding precision (float8 round-trip)
2. L2 weight/SF orientation (transpose correctness)
3. EP aggregation contract (local vs all-reduce)
4. Folded scale float8 precision loss
Usage: python diag_issues.py
"""
import torch
import sys
import os

# Try to import the model components
try:
    from nvfp4_megamoe_kernel import (
        transform_nvfp4_weights_for_mega_moe,
        stage_activation,
        nvfp4_mega_moe_full,
    )
    HAS_KERNEL = True
except ImportError:
    HAS_KERNEL = False
    print("WARNING: nvfp4_megamoe_kernel not importable, some checks will be skipped")


def check_fold_precision():
    """Check 1: Float8 folding precision.

    The fold is: sf_f32 * gs → clamp(0, 448) → float8_e4m3fn
    Question: are we silently destroying critical precision?
    """
    print("=" * 60)
    print("CHECK 1: Global Scale Folding Precision")
    print("=" * 60)
    # Simulate realistic scale distributions:
    # NVFP4 block scales (float8_e4m3fn) are typically in the range [0.06, 448]
    # Global scales are per-expert float32
    for gs_val in [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]:
        # 48 * 64 * 192 ≈ 590K simulated block scales (kept small for quantile())
        sf = torch.rand(48, 64, 192) * 448
        sf_f8 = sf.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
        sf_back = sf_f8.to(torch.float32)  # the float8 block scale as stored in the checkpoint
        # Fold: product, clamp to the float8 range, then cast back to float8
        product = sf_back * gs_val
        product_clamped = product.clamp(0.0, 448.0)
        folded_f8 = product_clamped.to(torch.float8_e4m3fn)
        folded_back = folded_f8.to(torch.float32)
        # Reference: exact float32 product of the stored float8 block scale and gs,
        # so the error below comes only from the fold's second float8 cast
        correct_product = sf_back * gs_val
        # Count how many values are lost to clamping or rounded to zero
        n_clamped = (product > 448.0).sum().item()
        n_zeroed = (folded_back == 0.0).sum().item() - (correct_product == 0.0).sum().item()
        # Relative error
        rel_err = (folded_back - correct_product).abs() / correct_product.clamp(min=1e-10)
        max_rel = rel_err.max().item()
        mean_rel = rel_err.mean().item()
        p99_rel = rel_err.quantile(0.99).item()
        print(f" gs={gs_val:>8.3f}: clamped={n_clamped:>8d} zeroed={n_zeroed:>8d} "
              f"max_rel={max_rel:.4f} mean_rel={mean_rel:.4f} p99_rel={p99_rel:.4f}")


def check_l2_orientation():
    """Check 2: L2 weight/SF orientation.

    The down_proj maps intermediate→hidden. In PyTorch, weight is (out, in) = (hidden, intermediate).
    After NVFP4 packing: (hidden, intermediate//2).
    After transpose for CUTLASS col-major B: (intermediate//2, hidden).
    The CUTLASS GEMM computes: D = alpha * A @ B where A is (M, K) and B is (K, N).
    K = intermediate (contraction dim), N = hidden (output dim).
    Packed B is (K_half, N) in memory (column-major for CUTLASS).
    Question: is the transpose correct for the CUTLASS B layout?
    """
    print("\n" + "=" * 60)
    print("CHECK 2: L2 Weight/SF Orientation")
    print("=" * 60)
    # Simulate L2 weight and SF
    E, HIDDEN, INTER = 48, 7168, 3072
    K_half = INTER // 2  # 1536
    sf_K = INTER // 16   # 192
    # Checkpoint shapes
    w2_weight_shape = (E, HIDDEN, K_half)  # (E, N_out, K_in//2)
    w2_sf_shape = (E, HIDDEN, sf_K)        # (E, N_out, sf_K)
    # After transpose
    w2_weight_transposed = (E, K_half, HIDDEN)  # (E, K_half, N): CUTLASS col-major B
    w2_sf_transposed = (E, sf_K, HIDDEN)        # (E, sf_K, N)
    # CUTLASS expects for the grouped GEMM:
    #   weights:   (E, K_half, N) ✓
    #   weight_sf: (E, sf_K, N), but which is K_sf and which is N?
    # The remap kernel gets: MN=N=HIDDEN, K_sf=INTER//16=192, col_major_src=true
    # Source is (K_sf, MN) = (192, 7168) row-major ✓
    print(f" Checkpoint w2_weight: {w2_weight_shape}")
    print(f" Checkpoint w2_sf: {w2_sf_shape}")
    print(f" After transpose w2_weight: {w2_weight_transposed}")
    print(f" After transpose w2_sf: {w2_sf_transposed}")
    print(f" CUTLASS expects B: (K_half={K_half}, N={HIDDEN})")
    print(f" CUTLASS expects SFB: (K_sf={sf_K}, N={HIDDEN})")
    print(f" ✓ Shapes match")
    # BUT: check if the DATA is semantically correct
    # The transpose swaps (N, K_half) → (K_half, N)
    # For the weight, row i of (N, K_half) becomes column i of (K_half, N)
    # In row-major, element [i,j] of (N, K_half) goes to offset i*K_half + j
    # After transpose, it's at offset j*N + i in (K_half, N)
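    # CUTLASS column-major B reads logical (n,k) at offset n + k*N
    # Where n is the output dim (hidden) and k is the contraction dim (intermediate)
    # For packed FP4, each byte holds 2 values along k, so the packed index k_half ranges 0..K_half-1
    # So logical (n, k_half) at offset n + k_half * N
    # Our data: element at memory offset k_half * N + n (row-major (K_half, N))
    # = k_half * N + n = n + k_half * N ← SAME ✓
    # CAVEAT: this equality assumes B is addressed with stride 1 along n (N-major).
    # In the standard cutlass::layout::ColumnMajor definition, a (K, N) B matrix stores
    # (k, n) at offset k + n*K (K-major). Confirm which stride the kernel's B layout uses.
    print(f" ✓ CUTLASS column-major stride matches our row-major (K_half, N) layout")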


def check_ep_aggregation():
    """Check 3: EP aggregation contract.

    Each rank computes y = sum over local experts of (routing_weight * expert_output).
    Then all-reduce sums across EP ranks.
    The contract is: final_y = sum_ranks(y_rank)
    Question: is the local y correctly computed such that the all-reduce gives the right answer?
    """
    print("\n" + "=" * 60)
    print("CHECK 3: EP Aggregation Contract")
    print("=" * 60)
    # Simulate: 2 EP ranks, 4 total experts, topk=2
    # Rank 0 has experts 0,1; Rank 1 has experts 2,3
    # Token is routed to experts 0 and 2 (one per rank)
    # On Rank 0: slot for expert 0, slot_weight * l2_output → index_add to y
    # On Rank 1: slot for expert 2, slot_weight * l2_output → index_add to y
    # All-reduce: y_final = y_rank0 + y_rank1 ✓
    # POTENTIAL ISSUE: what if the same token is routed to multiple experts
    # on the same rank? index_add_ handles this correctly (sums in-place).
    # POTENTIAL ISSUE: what if a token has NO experts on a rank?
    # y stays at 0 for that token → correct, other ranks contribute.
    # POTENTIAL ISSUE: is slot_weight correctly applied?
    # In nvfp4_mega_moe_full:
    #   y.index_add_(0, slot_token, l2_slots * slot_weight.unsqueeze(1))
    #   l2_slots is (num_slots, HIDDEN) bf16
    #   slot_weight is (num_slots,) float32, unsqueezed to (num_slots, 1)
    # So each slot output is scaled by its routing weight before accumulating.
    # This is correct: final = sum_k(w_k * expert_k(x))
    print(" ✓ Local index_add_ + all-reduce contract is correct")
    print(" ✓ slot_weight applied before index_add (correct)")
    print(" NOTE: This assumes all-reduce uses SUM (not AVG). Verify with torch.distributed.")
    # Check the vllm code uses all_reduce (sum by default)
    # torch.distributed.all_reduce defaults to ReduceOp.SUM ✓


def check_fold_vs_nofold():
    """Check 4: What happens if the global scale is NOT folded?

    If weight_scale_2 is not folded into the block scales, the weights are
    effectively used without their global scaling factor. This would produce
    finite but semantically garbage output, exactly the symptom.
    """
    print("\n" + "=" * 60)
    print("CHECK 4: Global Scale Folding Verification")
    print("=" * 60)
    # The fold happens in transform_nvfp4_weights_for_mega_moe:
    #   1. sf_f32 = weight_scale.to(float32)
    #   2. sf_f32 *= weight_scale_2 (global scale)
    #   3. sf_out = sf_f32.clamp(0, 448).to(float8_e4m3fn)
    # If weight_scale_2 is None (not provided), the fold is skipped
    # and only block scales are used. This would be a bug.
    # Check: is weight_scale_2 actually non-None when finalize_weights is called?
    # From the code:
    #   transform_nvfp4_weights_for_mega_moe(
    #       ..., l1_weight_scale_2=self.w13_weight_scale_2.data.contiguous(), ...)
    # self.w13_weight_scale_2 is initialized as nn.Parameter(torch.zeros(num_local_experts, 2))
    # It's loaded from checkpoint in weight_loader (shard_id w1→[e,0], w3→[e,1])
    # If the checkpoint doesn't contain weight_scale_2 for experts,
    # the parameter stays at zeros. Folding with gs=0 → all scales become 0 → garbage.
    print(" If weight_scale_2 is all zeros (not loaded from checkpoint):")
    sf = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0])
    gs_zero = 0.0
    folded = (sf * gs_zero).clamp(0, 448).to(torch.float8_e4m3fn)
    print(f" sf={sf.tolist()} * gs=0 → folded={folded.to(torch.float32).tolist()}")
    print(" ALL SCALES GO TO ZERO → all outputs are zero → garbage")
    print("\n If weight_scale_2 is correctly loaded (example values):")
    for g in [0.5, 1.0, 2.0, 5.0, 10.0]:
        folded = (sf * g).clamp(0, 448).to(torch.float8_e4m3fn)
        correct = sf * g
        rel_err = ((folded.to(torch.float32) - correct).abs() / correct).mean().item()
        print(f" gs={g:.1f}: mean_rel_err={rel_err:.4f}")


def check_l2_sf_transpose_semantics():
    """Check 5: After transposing L2 SF, is the data in the right layout?

    The w2_weight_scale in checkpoint is (E, N, sf_K) = (E, hidden, inter//16).
    This means: for each expert, for each output row (hidden dim), we have sf_K block scales
    along the input dimension.
    After transpose: (E, sf_K, N) = (E, inter//16, hidden).
    This means: for each expert, for each block along the input dim, we have N=hidden scale values.
    CUTLASS SFB is (K_sf, N) where K_sf is the contraction dim's scale groups.
    K_sf = K // 16 = inter // 16. N = hidden.
    The CUTLASS remap expects col_major_src=True, so it reads src[k_sf * N + n].
    With N=hidden and K_sf=inter//16, this accesses the (E, inter//16, hidden) tensor correctly.
    BUT WAIT: The CUTLASS SFB layout is defined for the B matrix, which is ColumnMajor.
    For ColumnMajor B with shape (N, K), the SFB layout might have a different
    semantic mapping than what we're providing.
    Let me check: does CUTLASS SFB index by (N_idx, K_sf_idx) or (K_sf_idx, N_idx)?
    """
    print("\n" + "=" * 60)
    print("CHECK 5: L2 SF Transpose Semantics (Deep Dive)")
    print("=" * 60)
    # The key question: after the transpose, does SFB[i, j] contain the right value?
    #
    # Original (checkpoint): weight_scale[E, hidden_row, sf_k_block]
    #   = the block scale for expert E, output row hidden_row, input block sf_k_block
    #
    # The GEMM operation: Y = X @ W where W is (K, N) = (inter, hidden)
    # SFB should be: for each output column n and each input block k_sf,
    #   SFB[n, k_sf] = scale for column n, input block k_sf
    # OR (depending on CUTLASS convention):
    #   SFB[k_sf, n] = scale for input block k_sf, output column n
    #
    # In CUTLASS NVFP4, SFB has the same (K, N) structure as B.
    # B is ColumnMajor (N, K), so B[n, k] is at memory n + k * N.
    # SFB should follow the same (N, K_sf) → ColumnMajor → (K_sf, N) row-major in memory.
    #
    # Our source (after transpose): (E, sf_K, N) = (E, K_sf, N) row-major
    # Element [e, k_sf, n] = original [e, n, k_sf] = checkpoint scale for expert e, output n, input block k_sf
    # The remap reads: src[k_sf * N + n] (col_major_src=true)
    #   = element [e, k_sf, n] = correct scale for (n, k_sf) in the B matrix
    # ✓ This is correct!
    print(" L2 SF transpose semantics are correct")
    print(" After transpose: (E, K_sf, N) with col_major_src=True")
    print(" remap reads src[k_sf * N + n] = original scale[e, n, k_sf] ✓")


def check_w13_gate_up_split():
    """Check 6: Is the gate/up split for w13 scale_2 folding aligned
    with the actual weight layout after transpose?

    w13_weight shape:       (E, 2*INTER, HIDDEN//2)
    w13_weight_scale shape: (E, 2*INTER, HIDDEN//16)
    The fold splits: gate = first INTER rows, up = last INTER rows,
    then applies gs[:,0] to gate, gs[:,1] to up.
    After transpose:
      w13_weight: (E, HIDDEN//2, 2*INTER)
      w13_sf:     (E, HIDDEN//16, 2*INTER)
    The gate/up split is now along the LAST dim (N), not the middle.
    But the fold happens BEFORE the transpose, so the split is correct.
    After transpose, the gate portion is columns 0..INTER-1 and up is INTER..2*INTER-1.
    This is still semantically correct for the CUTLASS GEMM.
    """
    print("\n" + "=" * 60)
    print("CHECK 6: w13 Gate/Up Split Alignment")
    print("=" * 60)
    print(" Fold happens before transpose → gate/up split is on dim 1 (N)")
    print(" After transpose, split is on dim 2 (N), the last dimension")
    print(" CUTLASS GEMM sees N=2*INTER with gate first, up second ✓")
    print(" The folded scales correctly reflect gate_gs and up_gs ✓")
if __name__ == "__main__":
if not torch.cuda.is_available():
print("WARNING: No CUDA — some checks will be approximate")
check_fold_precision()
check_l2_orientation()
check_ep_aggregation()
check_fold_vs_nofold()
check_l2_sf_transpose_semantics()
check_w13_gate_up_split()
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print("""
Most likely suspects for "finite but garbage" output:
1. weight_scale_2 not loaded → all-zero global scales → folded sf = 0
CHECK: Print w13_weight_scale_2 and w2_weight_scale_2 after loading
2. Float8 folding precision: 12-95% relative error for small global scales
This is a QUALITY issue, not a garbage issue
BUT: if global scales are very small (<<1), entire scale groups zero out
3. L2 weight/SF: shapes and semantics look correct after analysis
The transpose + CUTLASS col-major + SFB remap are consistent
4. EP aggregation: contract looks correct (local sum + all_reduce)
ACTION ITEMS:
a) Run the model with debug prints showing weight_scale_2 values
b) Check if any folded scales clamp to 0 or 448 (precision ceiling)
c) Compare folded sf values against reference (unfolded) computation
d) Test with a single expert to isolate EP issues
""")