Folding block_sf (float8) * global_sf (float32) back into float8 loses ~25% precision: the product of a ~56-448 block_sf and a ~4.65e-05 global_sf lands in the float8 low-precision zone, where the step size is 25%. This makes the model output garbage despite finite values.
Fix: keep the block scales as the original float8 and return the global scales separately as float32 per-expert vectors. Apply the global scale as a per-expert GEMM alpha in cutlass_grouped_nvfp4_gemm (which already iterates per expert). For L1, which has separate gate/up global scales, use gate_gs as alpha and apply the up_correction ratio to the up half post-GEMM.
Changes by file:
weight_transform.py: no more _fold_global_scale; returns (w, sf, global_sf).
nvfp4_mega_moe.py: per-expert alpha = activation_gs * weight_gs.
kernel.py: per_expert_alpha parameter in grouped GEMM.
deepseek_v4.py: updated type hints and comments.
329 lines
14 KiB
Python
329 lines
14 KiB
Python
"""
|
|
Diagnostic script for NVFP4 mega_moe issues.

Run on the B200 server. Checks:
1. Global scale folding precision (float8 round-trip)
2. L2 weight/SF orientation (transpose correctness)
3. EP aggregation contract (local vs all-reduce)
4. Global scale folding verification (zeroed scales and folded scale float8 precision loss)
5. L2 SF transpose semantics
6. w13 gate/up split alignment

Usage: python diag_issues.py
|
|
"""
|
|
|
|
import torch
|
|
import sys
|
|
import os
|
|
|
|
# Try to import the model components.
# The kernel package is treated as optional: HAS_KERNEL records whether the
# import succeeded so its absence is reported up front instead of aborting
# the script.  The imported names are not referenced by the checks visible
# in this file.
try:
    from nvfp4_megamoe_kernel import (  # noqa: F401
        transform_nvfp4_weights_for_mega_moe,
        stage_activation,
        nvfp4_mega_moe_full,
    )
except ImportError:
    HAS_KERNEL = False
    print("WARNING: nvfp4_megamoe_kernel not importable, some checks will be skipped")
else:
    # Only reached when every name above was imported successfully.
    HAS_KERNEL = True
|
|
|
|
def check_fold_precision():
|
|
"""Check 1: Float8 folding precision.
|
|
|
|
The fold is: sf_f32 * gs → clamp(0, 448) → float8_e4m3fn
|
|
Question: are we silently destroying critical precision?
|
|
"""
|
|
print("=" * 60)
|
|
print("CHECK 1: Global Scale Folding Precision")
|
|
print("=" * 60)
|
|
|
|
# Simulate realistic scale distributions
|
|
# NVFP4 block scales (float8_e4m3fn) are typically in range [0.06, 448]
|
|
# Global scales are per-expert float32
|
|
|
|
# Test with realistic ranges
|
|
for gs_val in [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]:
|
|
# Simulate 1000 block scales
|
|
sf = torch.rand(48, 64, 192) * 448 # Smaller for quantile perf
|
|
sf_f8 = sf.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
|
sf_back = sf_f8.to(torch.float32)
|
|
|
|
# Fold: product then cast back
|
|
product = sf_back * gs_val
|
|
product_clamped = product.clamp(0.0, 448.0)
|
|
folded_f8 = product_clamped.to(torch.float8_e4m3fn)
|
|
folded_back = folded_f8.to(torch.float32)
|
|
|
|
# Compare against the "correct" product (sf_f32 * gs, no float8 intermediate)
|
|
correct_product = sf * gs_val
|
|
|
|
# Count how many values are lost to clamping or zero
|
|
n_clamped = (product > 448.0).sum().item()
|
|
n_zeroed = (folded_back == 0.0).sum().item() - (correct_product == 0.0).sum().item()
|
|
|
|
# Relative error
|
|
rel_err = (folded_back - correct_product).abs() / correct_product.clamp(min=1e-10)
|
|
max_rel = rel_err.max().item()
|
|
mean_rel = rel_err.mean().item()
|
|
p99_rel = rel_err.quantile(0.99).item()
|
|
|
|
print(f" gs={gs_val:>8.3f}: clamped={n_clamped:>8d} zeroed={n_zeroed:>8d} "
|
|
f"max_rel={max_rel:.4f} mean_rel={mean_rel:.4f} p99_rel={p99_rel:.4f}")
|
|
|
|
def check_l2_orientation(num_experts=48, hidden=7168, inter=3072):
    """Check 2: L2 weight/SF orientation.

    The down_proj maps intermediate→hidden. In PyTorch, weight is
    (out, in) = (hidden, intermediate).
    After NVFP4 packing: (hidden, intermediate//2).
    After transpose for CUTLASS col-major B: (intermediate//2, hidden).

    The CUTLASS GEMM computes: D = alpha * A @ B where A is (M, K) and B is (K, N).
    K = intermediate (contraction dim), N = hidden (output dim).
    Packed B is (K_half, N) in memory (column-major for CUTLASS).

    Question: is the transpose correct for the CUTLASS B layout?

    The function prints the checkpoint and transposed shapes for the given
    dimensions, then verifies on a tiny matrix that a row-major (K_half, N)
    buffer is read correctly by the column-major offset formula
    ``n + k_half * N``.  No kernel is invoked.

    Args:
        num_experts: Number of experts (leading dimension of every shape).
        hidden: Output dimension N of the down projection.
        inter: Contraction dimension K.  It is floor-divided by 2 (packed
            FP4 bytes) and by 16 (scale blocks).

    Returns:
        None.  Results are printed to stdout.
    """
    print("\n" + "=" * 60)
    print("CHECK 2: L2 Weight/SF Orientation")
    print("=" * 60)

    # Simulate L2 weight and SF
    k_half = inter // 2   # 1536 with the default dims (2 FP4 values per byte)
    sf_k = inter // 16    # 192 with the default dims

    # Checkpoint shapes
    w2_weight_shape = (num_experts, hidden, k_half)   # (E, N_out, K_in//2)
    w2_sf_shape = (num_experts, hidden, sf_k)         # (E, N_out, sf_K)

    # After transpose: swap the last two dims of each checkpoint shape.
    # (E, K_half, N) — CUTLASS col-major B
    w2_weight_transposed = (w2_weight_shape[0], w2_weight_shape[2], w2_weight_shape[1])
    # (E, sf_K, N)
    w2_sf_transposed = (w2_sf_shape[0], w2_sf_shape[2], w2_sf_shape[1])

    # CUTLASS expects for the grouped GEMM:
    #   weights: (E, K_half, N) ✓
    #   weight_sf: (E, sf_K, N) — but which is K_sf and which is N?
    # The remap kernel gets: MN=N=HIDDEN, K_sf=INTER//16=192, col_major_src=true
    # Source is (K_sf, MN) = (192, 7168) row-major ✓  (default dims)

    print(f" Checkpoint w2_weight: {w2_weight_shape}")
    print(f" Checkpoint w2_sf: {w2_sf_shape}")
    print(f" After transpose w2_weight: {w2_weight_transposed}")
    print(f" After transpose w2_sf: {w2_sf_transposed}")
    print(f" CUTLASS expects B: (K_half={k_half}, N={hidden})")
    print(f" CUTLASS expects SFB: (K_sf={sf_k}, N={hidden})")

    shapes_ok = (
        w2_weight_transposed[1:] == (k_half, hidden)
        and w2_sf_transposed[1:] == (sf_k, hidden)
    )
    if shapes_ok:
        print(" ✓ Shapes match")
    else:
        print(" ✗ Shapes do NOT match")

    # BUT: check if the DATA is semantically correct
    # The transpose swaps (N, K_half) → (K_half, N)
    # For the weight, row i of (N, K_half) becomes column i of (K_half, N)
    # In row-major, element [i,j] of (N, K_half) goes to offset i*K_half + j
    # After transpose, it's at offset j*N + i in (K_half, N)
    # CUTLASS column-major B reads logical (n,k) at offset n + k*N
    # Where n is the output dim (hidden) and k is the contraction dim (intermediate)
    # For packed FP4: k ranges 0..K_half-1 (2 values per byte)
    # So logical (n, k_half) at offset n + k_half * N
    # Our data: element at memory offset k_half * N + n (row-major (K_half, N))
    #   = k_half * N + n = n + k_half * N ← SAME ✓
    #
    # Confirm the identity numerically on a tiny (N, K_half) matrix instead
    # of relying on the derivation alone.
    tiny_n, tiny_k = 5, 3
    logical = torch.arange(tiny_n * tiny_k).reshape(tiny_n, tiny_k)
    # Flat memory image of the transposed, row-major (K_half, N) tensor.
    memory = logical.t().contiguous().reshape(-1)
    # offsets[n, k] = n + k * N, i.e. the column-major read address.
    offsets = torch.arange(tiny_n).unsqueeze(1) + torch.arange(tiny_k).unsqueeze(0) * tiny_n
    layout_ok = torch.equal(memory[offsets], logical)
    if layout_ok:
        print(" ✓ CUTLASS column-major stride matches our row-major (K_half, N) layout")
    else:
        print(" ✗ CUTLASS column-major stride does NOT match our row-major (K_half, N) layout")
|
|
|
|
def check_ep_aggregation():
    """Check 3: EP aggregation contract.

    Each rank computes y = sum over local experts of (routing_weight * expert_output).
    Then all-reduce sums across EP ranks.

    The contract is: final_y = sum_ranks(y_rank)

    Question: is the local y correctly computed such that the all-reduce
    gives the right answer?

    This check is a written-down argument only: the reasoning lives in the
    comments below and the function prints its conclusions.  It runs no
    computation and does not touch torch.distributed.

    Returns:
        None.  Conclusions are printed to stdout.
    """
    print("\n" + "=" * 60)
    print("CHECK 3: EP Aggregation Contract")
    print("=" * 60)

    # Simulate: 2 EP ranks, 4 total experts, topk=2
    # Rank 0 has experts 0,1; Rank 1 has experts 2,3
    # Token is routed to experts 0 and 2 (one per rank)

    # On Rank 0: slot for expert 0, slot_weight * l2_output → index_add to y
    # On Rank 1: slot for expert 2, slot_weight * l2_output → index_add to y
    # All-reduce: y_final = y_rank0 + y_rank1 ✓

    # POTENTIAL ISSUE: what if the same token is routed to multiple experts
    # on the same rank? index_add_ handles this correctly (sums in-place).

    # POTENTIAL ISSUE: what if a token has NO experts on a rank?
    # y stays at 0 for that token → correct, other ranks contribute.

    # POTENTIAL ISSUE: is slot_weight correctly applied?
    # In nvfp4_mega_moe_full:
    #   y.index_add_(0, slot_token, l2_slots * slot_weight.unsqueeze(1))
    # l2_slots is (num_slots, HIDDEN) bf16
    # slot_weight is (num_slots,) float32, unsqueezed to (num_slots, 1)
    # So each slot output is scaled by its routing weight before accumulating.
    # This is correct: final = sum_k(w_k * expert_k(x))

    print(" ✓ Local index_add_ + all-reduce contract is correct")
    print(" ✓ slot_weight applied before index_add (correct)")
    print(" NOTE: This assumes all-reduce uses SUM (not AVG). Verify with torch.distributed.")

    # Check the vllm code uses all_reduce (sum by default)
    # torch.distributed.all_reduce defaults to ReduceOp.SUM ✓
|
|
|
|
def check_fold_vs_nofold():
|
|
"""Check 4: What happens if global scale is NOT folded?
|
|
|
|
If weight_scale_2 is not folded into the block scales, the weights are
|
|
effectively used without their global scaling factor. This would produce
|
|
finite but semantically garbage output — exactly the symptom.
|
|
"""
|
|
print("\n" + "=" * 60)
|
|
print("CHECK 4: Global Scale Folding Verification")
|
|
print("=" * 60)
|
|
|
|
# The fold happens in transform_nvfp4_weights_for_mega_moe:
|
|
# 1. sf_f32 = weight_scale.to(float32)
|
|
# 2. sf_f32 *= weight_scale_2 (global scale)
|
|
# 3. sf_out = sf_f32.clamp(0, 448).to(float8_e4m3fn)
|
|
|
|
# If weight_scale_2 is None (not provided), the fold is skipped
|
|
# and only block scales are used. This would be a bug.
|
|
|
|
# Check: is weight_scale_2 actually non-None when finalize_weights is called?
|
|
# From the code:
|
|
# transform_nvfp4_weights_for_mega_moe(
|
|
# ..., l1_weight_scale_2=self.w13_weight_scale_2.data.contiguous(), ...)
|
|
# self.w13_weight_scale_2 is initialized as nn.Parameter(torch.zeros(num_local_experts, 2))
|
|
# It's loaded from checkpoint in weight_loader (shard_id w1→[e,0], w3→[e,1])
|
|
|
|
# If the checkpoint doesn't contain weight_scale_2 for experts,
|
|
# the parameter stays at zeros. Folding with gs=0 → all scales become 0 → garbage.
|
|
|
|
print(" If weight_scale_2 is all zeros (not loaded from checkpoint):")
|
|
sf = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0])
|
|
gs_zero = 0.0
|
|
folded = (sf * gs_zero).clamp(0, 448).to(torch.float8_e4m3fn)
|
|
print(f" sf={sf.tolist()} * gs=0 → folded={folded.to(torch.float32).tolist()}")
|
|
print(" ALL SCALES GO TO ZERO → all outputs are zero → garbage")
|
|
|
|
print("\n If weight_scale_2 is correctly loaded (typical values):")
|
|
gs = torch.tensor([0.5, 1.0, 2.0, 5.0, 10.0])
|
|
for g in gs:
|
|
folded = (sf * g).clamp(0, 448).to(torch.float8_e4m3fn)
|
|
correct = sf * g
|
|
rel_err = ((folded.to(torch.float32) - correct).abs() / correct).mean()
|
|
print(f" gs={g:.1f}: mean_rel_err={rel_err:.4f}")
|
|
|
|
def check_l2_sf_transpose_semantics():
    """Check 5: After transposing L2 SF, is the data in the right layout?

    The w2_weight_scale in checkpoint is (E, N, sf_K) = (E, hidden, inter//16).
    This means: for each expert, for each output row (hidden dim), we have
    sf_K block scales along the input dimension.

    After transpose: (E, sf_K, N) = (E, inter//16, hidden).
    This means: for each expert, for each block along the input dim, we have
    N=hidden scale values.

    CUTLASS SFB is (K_sf, N) where K_sf is the contraction dim's scale groups.
    K_sf = K // 16 = inter // 16. N = hidden.

    The CUTLASS remap expects col_major_src=True, so it reads src[k_sf * N + m].
    With N=hidden and K_sf=inter//16, this accesses the (E, inter//16, hidden)
    tensor correctly.

    Open question examined in the comments below: the CUTLASS SFB layout is
    defined for the B matrix, which is ColumnMajor.  For ColumnMajor B with
    shape (N, K), the SFB layout might have a different semantic mapping
    than what we're providing — does CUTLASS SFB index by (N_idx, K_sf_idx)
    or (K_sf_idx, N_idx)?

    This check is a written-down argument only; it prints its conclusion
    and runs no computation.

    Returns:
        None.  Conclusions are printed to stdout.
    """
    print("\n" + "=" * 60)
    print("CHECK 5: L2 SF Transpose Semantics (Deep Dive)")
    print("=" * 60)

    # The key question: after the transpose, does SFB[i, j] contain the right value?
    #
    # Original (checkpoint): weight_scale[E, hidden_row, sf_k_block]
    #   = the block scale for expert E, output row hidden_row, input block sf_k_block
    #
    # The GEMM operation: Y = X @ W where W is (K, N) = (inter, hidden)
    # SFB should be: for each output column n and each input block k_sf,
    #   SFB[n, k_sf] = scale for column n, input block k_sf
    # OR (depending on CUTLASS convention):
    #   SFB[k_sf, n] = scale for input block k_sf, output column n
    #
    # In CUTLASS NVFP4, SFB has the same (K, N) structure as B.
    # B is ColumnMajor (N, K), so B[n, k] is at memory n + k * N.
    # SFB should follow the same (N, K_sf) → ColumnMajor → (K_sf, N) row-major in memory.
    #
    # Our source (after transpose): (E, sf_K, N) = (E, K_sf, N) row-major
    # Element [e, k_sf, n] = original [e, n, k_sf]
    #   = checkpoint scale for expert e, output n, input block k_sf
    # The remap reads: src[k_sf * N + n] (col_major_src=true)
    #   = element [e, k_sf, n] = correct scale for (n, k_sf) in the B matrix
    # ✓ This is correct!

    print(" L2 SF transpose semantics are correct")
    print(" After transpose: (E, K_sf, N) with col_major_src=True")
    print(" remap reads src[k_sf * N + n] = original scale[e, n, k_sf] ✓")
|
|
|
|
def check_w13_gate_up_split():
    """Check 6: Is the gate/up split for w13 scale_2 folding aligned
    with the actual weight layout after transpose?

    w13_weight shape: (E, 2*INTER, HIDDEN//2)
    w13_weight_scale shape: (E, 2*INTER, HIDDEN//16)

    The fold splits: gate = first INTER rows, up = last INTER rows
    Then applies gs[:,0] to gate, gs[:,1] to up

    After transpose:
        w13_weight: (E, HIDDEN//2, 2*INTER)
        w13_sf: (E, HIDDEN//16, 2*INTER)

    The gate/up split is now along the LAST dim (N), not the middle.
    But the fold happens BEFORE the transpose, so the split is correct.
    After transpose, the gate portion is columns 0..INTER-1 and up is
    INTER..2*INTER-1.
    This is still semantically correct for the CUTLASS GEMM.

    This check is a written-down argument only; it prints its conclusions
    and runs no computation.

    Returns:
        None.  Conclusions are printed to stdout.
    """
    print("\n" + "=" * 60)
    print("CHECK 6: w13 Gate/Up Split Alignment")
    print("=" * 60)
    print(" Fold happens before transpose → gate/up split is on dim 1 (N)")
    print(" After transpose, split is on dim 2 (N) — last dimension")
    print(" CUTLASS GEMM sees N=2*INTER with gate first, up second ✓")
    print(" The folded scales correctly reflect gate_gs and up_gs ✓")
|
|
|
|
# Script entry point: run every check in order, then print a fixed summary of
# the most likely causes and the follow-up actions.
if __name__ == "__main__":
    # The checks in this file run on CPU tensors; the warning only flags
    # that the machine has no CUDA device.
    if not torch.cuda.is_available():
        print("WARNING: No CUDA — some checks will be approximate")

    check_fold_precision()
    check_l2_orientation()
    check_ep_aggregation()
    check_fold_vs_nofold()
    check_l2_sf_transpose_semantics()
    check_w13_gate_up_split()

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    # The summary text is static; it is not derived from the results above.
    print("""
Most likely suspects for "finite but garbage" output:

1. weight_scale_2 not loaded → all-zero global scales → folded sf = 0
CHECK: Print w13_weight_scale_2 and w2_weight_scale_2 after loading

2. Float8 folding precision: 12-95% relative error for small global scales
This is a QUALITY issue, not a garbage issue
BUT: if global scales are very small (<<1), entire scale groups zero out

3. L2 weight/SF: shapes and semantics look correct after analysis
The transpose + CUTLASS col-major + SFB remap are consistent

4. EP aggregation: contract looks correct (local sum + all_reduce)

ACTION ITEMS:
a) Run the model with debug prints showing weight_scale_2 values
b) Check if any folded scales clamp to 0 or 448 (precision ceiling)
c) Compare folded sf values against reference (unfolded) computation
d) Test with a single expert to isolate EP issues
""")
|