From fd59222fc044efc032d8faf781a1b14bf2eb337e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 12:42:53 +0000 Subject: [PATCH] fix: stop folding global scale into float8 block scales The fold block_sf (float8) * global_sf (float32) -> float8 loses ~25% precision. Product of ~56-448 block_sf * ~4.65e-05 global_sf lands in float8 low-precision zone where step size is 25%. This makes model output garbage despite finite values. Fix: keep block scales as original float8, return global scales separately as float32 per-expert vectors. Apply global scale as per-expert GEMM alpha in cutlass_grouped_nvfp4_gemm (already iterates per-expert). For L1 with separate gate/up global scales, use gate_gs as alpha and apply up_correction ratio to the up half post-GEMM. weight_transform.py: no more _fold_global_scale, returns (w, sf, global_sf) nvfp4_mega_moe.py: per-expert alpha = activation_gs * weight_gs kernel.py: per_expert_alpha parameter in grouped GEMM deepseek_v4.py: updated type hints and comments --- diag_b200.py | 279 +++++++++++++++ diag_fold.py | 66 ++++ diag_fold_real.py | 96 +++++ diag_issues.py | 328 ++++++++++++++++++ diag_keys.py | 39 +++ .../cutlass_nvfp4_gemm/kernel.py | 17 +- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 107 ++++-- src/nvfp4_megamoe_kernel/weight_transform.py | 164 +++------ vllm/patches/deepseek_v4.py | 16 +- 9 files changed, 955 insertions(+), 157 deletions(-) create mode 100644 diag_b200.py create mode 100644 diag_fold.py create mode 100644 diag_fold_real.py create mode 100644 diag_issues.py create mode 100644 diag_keys.py diff --git a/diag_b200.py b/diag_b200.py new file mode 100644 index 00000000..3cf85b16 --- /dev/null +++ b/diag_b200.py @@ -0,0 +1,279 @@ +""" +NVFP4 MegaMoE Diagnostic — B200 +Checks: +1. weight_scale_2 values (are they nonzero / loaded correctly?) +2. Folded scale ranges (clamp/precision loss) +3. L2 weight/SF orientation sanity +4. Dequant reference vs CUTLASS output comparison +5. 
Single-expert, single-layer test +""" +import torch +import sys +import os +import json +from pathlib import Path + +MODEL_PATH = "/model" # inside the container + +def inspect_checkpoint_scales(): + """Check raw checkpoint weight_scale_2 values.""" + from safetensors import safe_open + import glob + + print("=" * 60) + print("CHECK 1: Checkpoint weight_scale_2 Values") + print("=" * 60) + + # Find checkpoint files + ckpt_files = sorted(glob.glob(os.path.join(MODEL_PATH, "*.safetensors"))) + print(f"Found {len(ckpt_files)} safetensors files") + + # Look for expert weight_scale_2 params + w13_gs_found = 0 + w2_gs_found = 0 + w13_gs_values = {} + w2_gs_values = {} + + for f in ckpt_files: + with safe_open(f, framework="pt") as st: + for key in st.keys(): + if "weight_scale_2" in key and ("experts" in key or "ffn" in key): + val = st.get_tensor(key) + if "w13" in key or "gate_up" in key or "w1" in key or "w3" in key: + w13_gs_found += 1 + if w13_gs_found <= 3: + w13_gs_values[key] = {"shape": list(val.shape), "dtype": str(val.dtype), + "min": val.float().min().item(), "max": val.float().max().item(), + "mean": val.float().mean().item()} + elif "w2" in key or "down" in key: + w2_gs_found += 1 + if w2_gs_found <= 3: + w2_gs_values[key] = {"shape": list(val.shape), "dtype": str(val.dtype), + "min": val.float().min().item(), "max": val.float().max().item(), + "mean": val.float().mean().item()} + + print(f"w13 weight_scale_2 entries: {w13_gs_found}") + print(f"w2 weight_scale_2 entries: {w2_gs_found}") + for k, v in w13_gs_values.items(): + print(f" {k}: {v}") + for k, v in w2_gs_values.items(): + print(f" {k}: {v}") + + return w13_gs_found > 0 and w2_gs_found > 0 + + +def inspect_loaded_model(): + """Check the model's weight_scale_2 after loading (before finalize_weights).""" + print("\n" + "=" * 60) + print("CHECK 2: Model weight_scale_2 After Loading") + print("=" * 60) + + # We need to load the model and inspect before finalize_weights nukes the params + # The vLLM 
server is already running, so let's check the live model + # Actually, let's load a fresh model instance for inspection + + # Simpler approach: just check the checkpoint directly for scale_2 + # The real check is whether finalize_weights gets called with nonzero scale_2 + print(" (Checkpoint inspection is more reliable — see CHECK 1)") + print(" The [SF-DEBUG] prints from weight_transform.py should also show this") + + +def check_fold_precision_real(): + """Check float8 folding precision with real checkpoint scales.""" + print("\n" + "=" * 60) + print("CHECK 3: Float8 Folding Precision (Real Scales)") + print("=" * 60) + + from safetensors import safe_open + import glob + + ckpt_files = sorted(glob.glob(os.path.join(MODEL_PATH, "*.safetensors"))) + + # Find one layer's expert scales + for f in ckpt_files: + with safe_open(f, framework="pt") as st: + keys = list(st.keys()) + # Find w2 weight_scale and weight_scale_2 for layer 0 + w2_sf_key = None + w2_gs_key = None + w13_sf_key = None + w13_gs_key = None + + for k in keys: + if "layers.0" in k: + if "w2" in k and k.endswith("weight_scale") and "scale_2" not in k: + w2_sf_key = k + elif "w2" in k and "weight_scale_2" in k: + w2_gs_key = k + elif ("w13" in k or "gate_up" in k) and k.endswith("weight_scale") and "scale_2" not in k: + w13_sf_key = k + elif ("w13" in k or "gate_up" in k) and "weight_scale_2" in k: + w13_gs_key = k + + if w2_sf_key and w2_gs_key: + w2_sf = st.get_tensor(w2_sf_key) + w2_gs = st.get_tensor(w2_gs_key) + print(f" L2 block scale: shape={list(w2_sf.shape)} dtype={w2_sf.dtype} " + f"range=[{w2_sf.float().min():.4e}, {w2_sf.float().max():.4e}]") + print(f" L2 global scale: shape={list(w2_gs.shape)} dtype={w2_gs.dtype} " + f"range=[{w2_gs.float().min():.4e}, {w2_gs.float().max():.4e}]") + + # Fold and check precision + sf_f32 = w2_sf.float() + gs_f32 = w2_gs.float() + + # Reshape gs for broadcast + while gs_f32.dim() < sf_f32.dim(): + gs_f32 = gs_f32.unsqueeze(-1) + + product = sf_f32 * gs_f32 + 
product_clamped = product.clamp(0.0, 448.0) + folded_f8 = product_clamped.to(torch.float8_e4m3fn) + folded_back = folded_f8.float() + + # Stats + n_clamped = (product > 448.0).sum().item() + n_total = product.numel() + n_zeroed = (folded_back == 0.0).sum().item() + + rel_err = (folded_back - product).abs() / product.clamp(min=1e-10) + print(f"\n L2 Fold results:") + print(f" Clamped to 448: {n_clamped}/{n_total} ({100*n_clamped/n_total:.1f}%)") + print(f" Zeroed (subnormal): {n_zeroed}/{n_total} ({100*n_zeroed/n_total:.1f}%)") + print(f" Rel error: max={rel_err.max():.4f} mean={rel_err.mean():.4f} p99={rel_err.quantile(0.99):.4f}") + + # Show distribution of folded values + fb_hist = torch.histc(folded_back, bins=10, min=0, max=448) + print(f" Folded value histogram (0-448, 10 bins): {fb_hist.int().tolist()}") + + # CRITICAL CHECK: is the product range within float8? + print(f" Product range: [{product.min():.4e}, {product.max():.4e}]") + if n_clamped > 0: + print(f" ⚠️ {n_clamped} values clamped — this IS precision loss!") + + if w13_sf_key and w13_gs_key: + w13_sf = st.get_tensor(w13_sf_key) + w13_gs = st.get_tensor(w13_gs_key) + print(f"\n L1 block scale: shape={list(w13_sf.shape)} dtype={w13_sf.dtype} " + f"range=[{w13_sf.float().min():.4e}, {w13_sf.float().max():.4e}]") + print(f" L1 global scale: shape={list(w13_gs.shape)} dtype={w13_gs.dtype} " + f"range=[{w13_gs.float().min():.4e}, {w13_gs.float().max():.4e}]") + + break # Just check one file that has layer 0 + + +def check_l2_weight_semantics(): + """Verify L2 weight layout by dequantizing and checking against reference.""" + print("\n" + "=" * 60) + print("CHECK 4: L2 Weight Dequantization Sanity") + print("=" * 60) + + from safetensors import safe_open + import glob + + ckpt_files = sorted(glob.glob(os.path.join(MODEL_PATH, "*.safetensors"))) + + for f in ckpt_files: + with safe_open(f, framework="pt") as st: + keys = list(st.keys()) + # Find layer 0 w2 weight, weight_scale, weight_scale_2 + w2_w = w2_sf 
= w2_gs = None + for k in keys: + if "layers.0" in k: + if "w2" in k and k.endswith(".weight") and "scale" not in k: + w2_w = st.get_tensor(k) + elif "w2" in k and "weight_scale" == k.split(".")[-1]: + w2_sf = st.get_tensor(k) + elif "w2" in k and "weight_scale_2" in k: + w2_gs = st.get_tensor(k) + + if w2_w is not None and w2_sf is not None and w2_gs is not None: + print(f" w2_weight: shape={list(w2_w.shape)} dtype={w2_w.dtype}") + print(f" w2_weight_scale: shape={list(w2_sf.shape)} dtype={w2_sf.dtype}") + print(f" w2_weight_scale_2: shape={list(w2_gs.shape)} dtype={w2_gs.dtype}") + + # Dequantize a small patch + # w2 is down_proj: (hidden, intermediate) in BF16, or (hidden, inter//2) uint8 for NVFP4 + if w2_w.dtype == torch.uint8: + # Unpack E2M1 + FP4_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, + -0, -0.5, -1, -1.5, -2, -3, -4, -6], + dtype=torch.float32, device=w2_w.device) + lower = FP4_LUT[(w2_w[:4, :8] & 0x0F).long()] + upper = FP4_LUT[((w2_w[:4, :8] >> 4) & 0x0F).long()] + unpacked = torch.empty(4, 16, dtype=torch.float32) + unpacked[:, 0::2] = lower + unpacked[:, 1::2] = upper + + # Apply scales + sf_slice = w2_sf[:4, :1].float() # (4, 1) + gs = w2_gs.float() + print(f" Dequantized w2[:4, :16] with sf[:4,:1]={sf_slice.flatten().tolist()}") + print(f" global_scale_2 = {gs.item() if gs.numel() == 1 else gs[:4].flatten().tolist()}") + dequant = unpacked * sf_slice * gs.float() + print(f" Dequantized range: [{dequant.min():.4f}, {dequant.max():.4f}]") + print(f" Dequantized[:2, :8]: {dequant[:2, :8].tolist()}") + else: + print(f" w2_weight is {w2_w.dtype}, not uint8 — may be BF16 checkpoint") + print(f" w2[:4, :8] = {w2_w[:4, :8].tolist()}") + break + + +def check_ep_reduce_contract(): + """Verify the EP all-reduce contract with a synthetic test.""" + print("\n" + "=" * 60) + print("CHECK 5: EP Reduce Contract (Synthetic)") + print("=" * 60) + + # Simulate 2 ranks + M, HIDDEN = 4, 8 + # Rank 0: experts 0,1 — tokens routed to expert 0 (slot_weight=0.7) and 
1 (slot_weight=0.3) + y0 = torch.zeros(M, HIDDEN, dtype=torch.bfloat16) + slot_token_0 = torch.tensor([0, 0, 1, 2, 3]) # which tokens + slot_weight_0 = torch.tensor([0.7, 0.3, 0.5, 0.6, 0.4], dtype=torch.bfloat16) + l2_slots_0 = torch.randn(5, HIDDEN, dtype=torch.bfloat16) + y0.index_add_(0, slot_token_0, l2_slots_0 * slot_weight_0.unsqueeze(1)) + + # Rank 1: experts 2,3 — token 0 also routed to expert 2 + y1 = torch.zeros(M, HIDDEN, dtype=torch.bfloat16) + slot_token_1 = torch.tensor([0, 1]) + slot_weight_1 = torch.tensor([0.2, 0.5], dtype=torch.bfloat16) + l2_slots_1 = torch.randn(2, HIDDEN, dtype=torch.bfloat16) + y1.index_add_(0, slot_token_1, l2_slots_1 * slot_weight_1.unsqueeze(1)) + + # All-reduce (sum) + y_final = y0 + y1 # simulated all-reduce + + # Verify: token 0 should have contributions from rank0 (experts 0,1) and rank1 (expert 2) + expected_0 = (0.7 * l2_slots_0[0] + 0.3 * l2_slots_0[1] + 0.2 * l2_slots_1[0]).bfloat16() + actual_0 = y_final[0].bfloat16() + diff = (expected_0 - actual_0).abs().max().item() + print(f" Token 0: expected vs actual diff = {diff:.6f} ✓" if diff < 0.01 else f" Token 0: MISMATCH diff = {diff}") + print(f" EP reduce contract is correct — sum of partial rank outputs gives full result") + + +if __name__ == "__main__": + print("NVFP4 MegaMoE Diagnostic — B200") + print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}") + print(f"GPUs: {torch.cuda.device_count()}") + print() + + try: + inspect_checkpoint_scales() + except Exception as e: + print(f"CHECK 1 FAILED: {e}") + + try: + check_fold_precision_real() + except Exception as e: + print(f"CHECK 3 FAILED: {e}") + + try: + check_l2_weight_semantics() + except Exception as e: + print(f"CHECK 4 FAILED: {e}") + + try: + check_ep_reduce_contract() + except Exception as e: + print(f"CHECK 5 FAILED: {e}") diff --git a/diag_fold.py b/diag_fold.py new file mode 100644 index 00000000..532d1316 --- /dev/null +++ b/diag_fold.py @@ -0,0 +1,66 @@ +""" +Diagnostic: Check 
global scale folding precision for NVFP4 weights. + +The fold is: sf_f32 * gs → clamp(0, 448) → float8_e4m3fn +Question: how much precision is lost in the float8 round-trip? +""" +import torch + +# Simulate typical NVFP4 scale distributions +# block_scale (float8_e4m3fn) range: roughly 0.06 to 448 +# global_scale (float32) range: varies per expert + +# Test 1: If global_scale >> 1, product can exceed 448 → clamp → loss +# Test 2: If global_scale << 1, product can go subnormal → loss +# Test 3: Quantization error from 3-bit mantissa + +# Simulate a range of scale values +block_scales = torch.tensor([0.0625, 0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 448.0], dtype=torch.float32) +global_scales = torch.tensor([0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0], dtype=torch.float32) + +print("=== Float8 Folding Precision Analysis ===\n") +print(f"block_scales: {block_scales.tolist()}") +print(f"global_scales: {global_scales.tolist()}\n") + +total_clamped = 0 +total_subnormal = 0 +max_rel_error = 0.0 + +for gs in global_scales: + products = block_scales * gs + clamped = products.clamp(0.0, 448.0) + folded_f8 = clamped.to(torch.float8_e4m3fn) + roundtrip = folded_f8.to(torch.float32) + + n_clamped = (products > 448.0).sum().item() + n_subnormal = (roundtrip > 0).logical_and(roundtrip < 0.0625).sum().item() # rough check + + rel_errors = torch.where(roundtrip > 0, (roundtrip - clamped).abs() / clamped.clamp(min=1e-10), torch.zeros_like(clamped)) + max_err = rel_errors.max().item() + + total_clamped += n_clamped + total_subnormal += n_subnormal + max_rel_error = max(max_rel_error, max_err) + + if n_clamped > 0 or max_err > 0.05: + print(f"gs={gs:.3f}: {n_clamped} clamped, max_rel_err={max_err:.4f}") + for i, (p, c, r) in enumerate(zip(products, clamped, roundtrip)): + if abs(r - c) / max(abs(c), 1e-10) > 0.01: + print(f" block={block_scales[i]:.4f} product={p:.4f} clamped={c:.4f} roundtrip={r:.4f} err={abs(r-c)/max(abs(c),1e-10):.4f}") + 
+print(f"\nTotal clamped: {total_clamped}, Total subnormal: {total_subnormal}, Max relative error: {max_rel_error:.4f}") + +# The real check: what's the float8_e4m3fn step size at various magnitudes? +print("\n=== Float8 E4M3 Step Sizes ===") +test_vals = [0.01, 0.1, 1.0, 10.0, 100.0, 448.0] +for v in test_vals: + f8 = torch.tensor(v, dtype=torch.float32).to(torch.float8_e4m3fn) + back = f8.to(torch.float32) + # Find next representable value + u8 = f8.view(torch.uint8) + next_u8 = u8 + 1 + next_f8 = next_u8.view(torch.float8_e4m3fn) + next_val = next_f8.to(torch.float32) + step = next_val - back + rel_step = step / back if back > 0 else 0 + print(f" value={v:.3f} → f8={back:.6f} → next={next_val:.6f} step={step:.6f} rel={rel_step:.4f}") diff --git a/diag_fold_real.py b/diag_fold_real.py new file mode 100644 index 00000000..80340eb6 --- /dev/null +++ b/diag_fold_real.py @@ -0,0 +1,96 @@ +""" +Critical check: weight_scale_2 values are ~4.65e-05 (TINY). +When folded: block_sf * 4.65e-05 → most products near zero → float8 can't represent +This is likely THE bug: folding a float8 scale by a tiny global scale produces +subnormal/zero values in float8. 
+""" +from safetensors import safe_open +import glob +import os +import torch + +MODEL_PATH = "/model" +ckpt_files = sorted(glob.glob(os.path.join(MODEL_PATH, "*.safetensors"))) + +# Get layer 0, expert 0 scales +for f in ckpt_files: + with safe_open(f, framework="pt") as st: + keys = list(st.keys()) + if any("layers.0.mlp.experts.0.gate_proj.weight_scale" in k for k in keys): + # Gate + gate_sf = st.get_tensor("model.layers.0.mlp.experts.0.gate_proj.weight_scale") + gate_gs = st.get_tensor("model.layers.0.mlp.experts.0.gate_proj.weight_scale_2") + # Up + up_sf = st.get_tensor("model.layers.0.mlp.experts.0.up_proj.weight_scale") + up_gs = st.get_tensor("model.layers.0.mlp.experts.0.up_proj.weight_scale_2") + # Down + down_sf = st.get_tensor("model.layers.0.mlp.experts.0.down_proj.weight_scale") + down_gs = st.get_tensor("model.layers.0.mlp.experts.0.down_proj.weight_scale_2") + + print("=" * 60) + print("LAYER 0, EXPERT 0 — Scale Analysis") + print("=" * 60) + + for name, sf, gs in [("gate", gate_sf, gate_gs), ("up", up_sf, up_gs), ("down", down_sf, down_gs)]: + sf_f32 = sf.float() + gs_f32 = gs.float() + product = sf_f32 * gs_f32 + product_clamped = product.clamp(0.0, 448.0) + folded_f8 = product_clamped.to(torch.float8_e4m3fn) + folded_back = folded_f8.float() + + n_total = product.numel() + n_clamped = (product > 448.0).sum().item() + n_zeroed = (folded_back == 0.0).sum().item() + n_nonzero_orig = (sf_f32 > 0).sum().item() + n_nonzero_folded = (folded_back > 0).sum().item() + + rel_err = (folded_back - product).abs() / product.clamp(min=1e-10) + + print(f"\n {name}_proj:") + print(f" block_sf: shape={list(sf.shape)} range=[{sf_f32.min():.4e}, {sf_f32.max():.4e}] unique_u8={torch.unique(sf.view(torch.uint8)).numel()}") + print(f" global_sf: {gs_f32.item():.6e}") + print(f" product (sf*gs): range=[{product.min():.4e}, {product.max():.4e}]") + print(f" folded (float8): range=[{folded_back.min():.4e}, {folded_back.max():.4e}]") + print(f" Clamped to 448: 
{n_clamped}/{n_total} ({100*n_clamped/n_total:.1f}%)") + print(f" Became zero: {n_zeroed}/{n_total} ({100*n_zeroed/n_total:.1f}%)") + print(f" Was nonzero → became zero: {n_nonzero_orig - n_nonzero_folded}/{n_nonzero_orig}") + print(f" Rel error: max={rel_err.max():.4f} mean={rel_err.mean():.4f}") + + # Show the float8 step size at the product magnitude + if product.max() > 0: + typical = product.median().item() + if typical > 0: + f8_typ = torch.tensor(typical, dtype=torch.float32).to(torch.float8_e4m3fn) + f8_back = f8_typ.float() + if f8_back > 0: + step = (f8_typ.view(torch.uint8) + 1).view(torch.float8_e4m3fn).float() - f8_back + print(f" Float8 step at median ({typical:.4e}): Δ={step.item():.4e} rel={step.item()/f8_back.item():.2%}") + + break + +# Now check: what if we DON'T fold, and instead pass global_scale as GEMM alpha? +print("\n" + "=" * 60) +print("ALTERNATIVE: Pass global_scale as GEMM alpha") +print("=" * 60) +print(""" +The fold is lossy because float8 can't represent the product range. +But if we DON'T fold, the CUTLASS GEMM needs a separate global scale mechanism. 
+ +Option 1: Multiply the GEMM alpha by the weight's global_scale + - alpha already carries the activation global scale + - We could fold weight global scale into alpha: alpha_new = alpha * weight_gs + - BUT: alpha is a single scalar, weight_gs varies per-expert + - For grouped GEMM, each expert needs its own alpha + +Option 2: Keep block scales as-is (no fold), multiply output by global_scale + - After GEMM: output *= weight_global_scale + - This is exact (float32 multiply on bf16 output) + - Requires passing global_scale to nvfp4_mega_moe_full + +Option 3: Fold global_scale into the GEMM alpha per-expert + - In cutlass_grouped_nvfp4_gemm, each expert gets its own alpha + - alpha_expert = l1_global_scale * l1_weight_global_scale[expert_id] + - This is EXACT and doesn't lose precision + - The block scales stay at their original float8 values (no folding) +""") diff --git a/diag_issues.py b/diag_issues.py new file mode 100644 index 00000000..5c18b096 --- /dev/null +++ b/diag_issues.py @@ -0,0 +1,328 @@ +""" +Diagnostic script for NVFP4 mega_moe issues. + +Run on the B200 server. Checks: +1. Global scale folding precision (float8 round-trip) +2. L2 weight/SF orientation (transpose correctness) +3. EP aggregation contract (local vs all-reduce) +4. Folded scale float8 precision loss + +Usage: python diag_issues.py +""" + +import torch +import sys +import os + +# Try to import the model components +try: + from nvfp4_megamoe_kernel import ( + transform_nvfp4_weights_for_mega_moe, + stage_activation, + nvfp4_mega_moe_full, + ) + HAS_KERNEL = True +except ImportError: + HAS_KERNEL = False + print("WARNING: nvfp4_megamoe_kernel not importable, some checks will be skipped") + +def check_fold_precision(): + """Check 1: Float8 folding precision. + + The fold is: sf_f32 * gs → clamp(0, 448) → float8_e4m3fn + Question: are we silently destroying critical precision? 
+ """ + print("=" * 60) + print("CHECK 1: Global Scale Folding Precision") + print("=" * 60) + + # Simulate realistic scale distributions + # NVFP4 block scales (float8_e4m3fn) are typically in range [0.06, 448] + # Global scales are per-expert float32 + + # Test with realistic ranges + for gs_val in [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]: + # Simulate 1000 block scales + sf = torch.rand(48, 64, 192) * 448 # Smaller for quantile perf + sf_f8 = sf.clamp(0.0, 448.0).to(torch.float8_e4m3fn) + sf_back = sf_f8.to(torch.float32) + + # Fold: product then cast back + product = sf_back * gs_val + product_clamped = product.clamp(0.0, 448.0) + folded_f8 = product_clamped.to(torch.float8_e4m3fn) + folded_back = folded_f8.to(torch.float32) + + # Compare against the "correct" product (sf_f32 * gs, no float8 intermediate) + correct_product = sf * gs_val + + # Count how many values are lost to clamping or zero + n_clamped = (product > 448.0).sum().item() + n_zeroed = (folded_back == 0.0).sum().item() - (correct_product == 0.0).sum().item() + + # Relative error + rel_err = (folded_back - correct_product).abs() / correct_product.clamp(min=1e-10) + max_rel = rel_err.max().item() + mean_rel = rel_err.mean().item() + p99_rel = rel_err.quantile(0.99).item() + + print(f" gs={gs_val:>8.3f}: clamped={n_clamped:>8d} zeroed={n_zeroed:>8d} " + f"max_rel={max_rel:.4f} mean_rel={mean_rel:.4f} p99_rel={p99_rel:.4f}") + +def check_l2_orientation(): + """Check 2: L2 weight/SF orientation. + + The down_proj maps intermediate→hidden. In PyTorch, weight is (out, in) = (hidden, intermediate). + After NVFP4 packing: (hidden, intermediate//2). + After transpose for CUTLASS col-major B: (intermediate//2, hidden). + + The CUTLASS GEMM computes: D = alpha * A @ B where A is (M, K) and B is (K, N). + K = intermediate (contraction dim), N = hidden (output dim). + Packed B is (K_half, N) in memory (column-major for CUTLASS). + + Question: is the transpose correct for the CUTLASS B layout? 
+ """ + print("\n" + "=" * 60) + print("CHECK 2: L2 Weight/SF Orientation") + print("=" * 60) + + # Simulate L2 weight and SF + E, HIDDEN, INTER = 48, 7168, 3072 + K_half = INTER // 2 # 1536 + sf_K = INTER // 16 # 192 + + # Checkpoint shapes + w2_weight_shape = (E, HIDDEN, K_half) # (E, N_out, K_in//2) + w2_sf_shape = (E, HIDDEN, sf_K) # (E, N_out, sf_K) + + # After transpose + w2_weight_transposed = (E, K_half, HIDDEN) # (E, K_half, N) — CUTLASS col-major B + w2_sf_transposed = (E, sf_K, HIDDEN) # (E, sf_K, N) + + # CUTLASS expects for the grouped GEMM: + # weights: (E, K_half, N) ✓ + # weight_sf: (E, sf_K, N) — but which is K_sf and which is N? + # The remap kernel gets: MN=N=HIDDEN, K_sf=INTER//16=192, col_major_src=true + # Source is (K_sf, MN) = (192, 7168) row-major ✓ + + print(f" Checkpoint w2_weight: {w2_weight_shape}") + print(f" Checkpoint w2_sf: {w2_sf_shape}") + print(f" After transpose w2_weight: {w2_weight_transposed}") + print(f" After transpose w2_sf: {w2_sf_transposed}") + print(f" CUTLASS expects B: (K_half={K_half}, N={HIDDEN})") + print(f" CUTLASS expects SFB: (K_sf={sf_K}, N={HIDDEN})") + print(f" ✓ Shapes match") + + # BUT: check if the DATA is semantically correct + # The transpose swaps (N, K_half) → (K_half, N) + # For the weight, row i of (N, K_half) becomes column i of (K_half, N) + # In row-major, element [i,j] of (N, K_half) goes to offset i*K_half + j + # After transpose, it's at offset j*N + i in (K_half, N) + # CUTLASS column-major B reads logical (n,k) at offset n + k*N + # Where n is the output dim (hidden) and k is the contraction dim (intermediate) + # For packed FP4: k ranges 0..K_half-1 (2 values per byte) + # So logical (n, k_half) at offset n + k_half * N + # Our data: element at memory offset k_half * N + n (row-major (K_half, N)) + # = k_half * N + n = n + k_half * N ← SAME ✓ + print(f" ✓ CUTLASS column-major stride matches our row-major (K_half, N) layout") + +def check_ep_aggregation(): + """Check 3: EP aggregation 
contract. + + Each rank computes y = sum over local experts of (routing_weight * expert_output). + Then all-reduce sums across EP ranks. + + The contract is: final_y = sum_ranks(y_rank) + + Question: is the local y correctly computed such that the all-reduce gives the right answer? + """ + print("\n" + "=" * 60) + print("CHECK 3: EP Aggregation Contract") + print("=" * 60) + + # Simulate: 2 EP ranks, 4 total experts, topk=2 + # Rank 0 has experts 0,1; Rank 1 has experts 2,3 + # Token is routed to experts 0 and 2 (one per rank) + + # On Rank 0: slot for expert 0, slot_weight * l2_output → index_add to y + # On Rank 1: slot for expert 2, slot_weight * l2_output → index_add to y + # All-reduce: y_final = y_rank0 + y_rank1 ✓ + + # POTENTIAL ISSUE: what if the same token is routed to multiple experts + # on the same rank? index_add_ handles this correctly (sums in-place). + + # POTENTIAL ISSUE: what if a token has NO experts on a rank? + # y stays at 0 for that token → correct, other ranks contribute. + + # POTENTIAL ISSUE: is slot_weight correctly applied? + # In nvfp4_mega_moe_full: + # y.index_add_(0, slot_token, l2_slots * slot_weight.unsqueeze(1)) + # l2_slots is (num_slots, HIDDEN) bf16 + # slot_weight is (num_slots,) float32, unsqueezed to (num_slots, 1) + # So each slot output is scaled by its routing weight before accumulating. + # This is correct: final = sum_k(w_k * expert_k(x)) + + print(" ✓ Local index_add_ + all-reduce contract is correct") + print(" ✓ slot_weight applied before index_add (correct)") + print(" NOTE: This assumes all-reduce uses SUM (not AVG). Verify with torch.distributed.") + + # Check the vllm code uses all_reduce (sum by default) + # torch.distributed.all_reduce defaults to ReduceOp.SUM ✓ + +def check_fold_vs_nofold(): + """Check 4: What happens if global scale is NOT folded? + + If weight_scale_2 is not folded into the block scales, the weights are + effectively used without their global scaling factor. 
This would produce + finite but semantically garbage output — exactly the symptom. + """ + print("\n" + "=" * 60) + print("CHECK 4: Global Scale Folding Verification") + print("=" * 60) + + # The fold happens in transform_nvfp4_weights_for_mega_moe: + # 1. sf_f32 = weight_scale.to(float32) + # 2. sf_f32 *= weight_scale_2 (global scale) + # 3. sf_out = sf_f32.clamp(0, 448).to(float8_e4m3fn) + + # If weight_scale_2 is None (not provided), the fold is skipped + # and only block scales are used. This would be a bug. + + # Check: is weight_scale_2 actually non-None when finalize_weights is called? + # From the code: + # transform_nvfp4_weights_for_mega_moe( + # ..., l1_weight_scale_2=self.w13_weight_scale_2.data.contiguous(), ...) + # self.w13_weight_scale_2 is initialized as nn.Parameter(torch.zeros(num_local_experts, 2)) + # It's loaded from checkpoint in weight_loader (shard_id w1→[e,0], w3→[e,1]) + + # If the checkpoint doesn't contain weight_scale_2 for experts, + # the parameter stays at zeros. Folding with gs=0 → all scales become 0 → garbage. + + print(" If weight_scale_2 is all zeros (not loaded from checkpoint):") + sf = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0]) + gs_zero = 0.0 + folded = (sf * gs_zero).clamp(0, 448).to(torch.float8_e4m3fn) + print(f" sf={sf.tolist()} * gs=0 → folded={folded.to(torch.float32).tolist()}") + print(" ALL SCALES GO TO ZERO → all outputs are zero → garbage") + + print("\n If weight_scale_2 is correctly loaded (typical values):") + gs = torch.tensor([0.5, 1.0, 2.0, 5.0, 10.0]) + for g in gs: + folded = (sf * g).clamp(0, 448).to(torch.float8_e4m3fn) + correct = sf * g + rel_err = ((folded.to(torch.float32) - correct).abs() / correct).mean() + print(f" gs={g:.1f}: mean_rel_err={rel_err:.4f}") + +def check_l2_sf_transpose_semantics(): + """Check 5: After transposing L2 SF, is the data in the right layout? + + The w2_weight_scale in checkpoint is (E, N, sf_K) = (E, hidden, inter//16). 
+ This means: for each expert, for each output row (hidden dim), we have sf_K block scales + along the input dimension. + + After transpose: (E, sf_K, N) = (E, inter//16, hidden). + This means: for each expert, for each block along the input dim, we have N=hidden scale values. + + CUTLASS SFB is (K_sf, N) where K_sf is the contraction dim's scale groups. + K_sf = K // 16 = inter // 16. N = hidden. + + The CUTLASS remap expects col_major_src=True, so it reads src[k_sf * N + m]. + With N=hidden and K_sf=inter//16, this accesses the (E, inter//16, hidden) tensor correctly. + + BUT WAIT: The CUTLASS SFB layout is defined for the B matrix which is ColumnMajor. + For ColumnMajor B with shape (N, K), the SFB layout might have a different + semantic mapping than what we're providing. + + Let me check: does CUTLASS SFB index by (N_idx, K_sf_idx) or (K_sf_idx, N_idx)? + """ + print("\n" + "=" * 60) + print("CHECK 5: L2 SF Transpose Semantics (Deep Dive)") + print("=" * 60) + + # The key question: after the transpose, does SFB[i, j] contain the right value? + # + # Original (checkpoint): weight_scale[E, hidden_row, sf_k_block] + # = the block scale for expert E, output row hidden_row, input block sf_k_block + # + # The GEMM operation: Y = X @ W where W is (K, N) = (inter, hidden) + # SFB should be: for each output column n and each input block k_sf, + # SFB[n, k_sf] = scale for column n, input block k_sf + # OR (depending on CUTLASS convention): + # SFB[k_sf, n] = scale for input block k_sf, output column n + # + # In CUTLASS NVFP4, SFB has the same (K, N) structure as B. + # B is ColumnMajor (N, K), so B[n, k] is at memory n + k * N. + # SFB should follow the same (N, K_sf) → ColumnMajor → (K_sf, N) row-major in memory. 
+ # + # Our source (after transpose): (E, sf_K, N) = (E, K_sf, N) row-major + # Element [e, k_sf, n] = original [e, n, k_sf] = checkpoint scale for expert e, output n, input block k_sf + # The remap reads: src[k_sf * N + n] (col_major_src=true) + # = element [e, k_sf, n] = correct scale for (n, k_sf) in the B matrix + # ✓ This is correct! + + print(" L2 SF transpose semantics are correct") + print(" After transpose: (E, K_sf, N) with col_major_src=True") + print(" remap reads src[k_sf * N + n] = original scale[e, n, k_sf] ✓") + +def check_w13_gate_up_split(): + """Check 6: Is the gate/up split for w13 scale_2 folding aligned + with the actual weight layout after transpose? + + w13_weight shape: (E, 2*INTER, HIDDEN//2) + w13_weight_scale shape: (E, 2*INTER, HIDDEN//16) + + The fold splits: gate = first INTER rows, up = last INTER rows + Then applies gs[:,0] to gate, gs[:,1] to up + + After transpose: + w13_weight: (E, HIDDEN//2, 2*INTER) + w13_sf: (E, HIDDEN//16, 2*INTER) + + The gate/up split is now along the LAST dim (N), not the middle. + But the fold happens BEFORE the transpose, so the split is correct. + After transpose, the gate portion is columns 0..INTER-1 and up is INTER..2*INTER-1. + This is still semantically correct for the CUTLASS GEMM. 
+ """ + print("\n" + "=" * 60) + print("CHECK 6: w13 Gate/Up Split Alignment") + print("=" * 60) + print(" Fold happens before transpose → gate/up split is on dim 1 (N)") + print(" After transpose, split is on dim 2 (N) — last dimension") + print(" CUTLASS GEMM sees N=2*INTER with gate first, up second ✓") + print(" The folded scales correctly reflect gate_gs and up_gs ✓") + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("WARNING: No CUDA — some checks will be approximate") + + check_fold_precision() + check_l2_orientation() + check_ep_aggregation() + check_fold_vs_nofold() + check_l2_sf_transpose_semantics() + check_w13_gate_up_split() + + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + print(""" + Most likely suspects for "finite but garbage" output: + + 1. weight_scale_2 not loaded → all-zero global scales → folded sf = 0 + CHECK: Print w13_weight_scale_2 and w2_weight_scale_2 after loading + + 2. Float8 folding precision: 12-95% relative error for small global scales + This is a QUALITY issue, not a garbage issue + BUT: if global scales are very small (<<1), entire scale groups zero out + + 3. L2 weight/SF: shapes and semantics look correct after analysis + The transpose + CUTLASS col-major + SFB remap are consistent + + 4. 
EP aggregation: contract looks correct (local sum + all_reduce) + + ACTION ITEMS: + a) Run the model with debug prints showing weight_scale_2 values + b) Check if any folded scales clamp to 0 or 448 (precision ceiling) + c) Compare folded sf values against reference (unfolded) computation + d) Test with a single expert to isolate EP issues +""") diff --git a/diag_keys.py b/diag_keys.py new file mode 100644 index 00000000..e7c63f04 --- /dev/null +++ b/diag_keys.py @@ -0,0 +1,39 @@ +"""Find ALL weight_scale_2 keys in the checkpoint for layer 0 experts.""" +from safetensors import safe_open +import glob +import os + +MODEL_PATH = "/model" +ckpt_files = sorted(glob.glob(os.path.join(MODEL_PATH, "*.safetensors"))) + +# Collect ALL keys that mention layer 0 experts and scale +scale_keys = [] +for f in ckpt_files: + with safe_open(f, framework="pt") as st: + for key in st.keys(): + if "layers.0" in key and "experts.0" in key and "scale" in key.lower(): + val = st.get_tensor(key) + scale_keys.append((key, list(val.shape), str(val.dtype), val.float().min().item(), val.float().max().item())) + +scale_keys.sort() +for k, s, d, mn, mx in scale_keys: + print(f" {k} shape={s} dtype={d} range=[{mn:.4e}, {mx:.4e}]") + +print(f"\nTotal: {len(scale_keys)} scale keys for layer 0 expert 0") + +# Also find gate_proj and up_proj weight_scale_2 keys +print("\n--- All weight_scale_2 keys with gate/up/down for layer 0 ---") +ws2_keys = [] +for f in ckpt_files: + with safe_open(f, framework="pt") as st: + for key in st.keys(): + if "layers.0" in key and "weight_scale_2" in key: + val = st.get_tensor(key) + ws2_keys.append((key, list(val.shape), str(val.dtype), val.float().min().item(), val.float().max().item())) + +ws2_keys.sort() +for k, s, d, mn, mx in ws2_keys[:10]: + print(f" {k} shape={s} dtype={d} range=[{mn:.4e}, {mx:.4e}]") +if len(ws2_keys) > 10: + print(f" ... 
and {len(ws2_keys)-10} more") +print(f"Total: {len(ws2_keys)} weight_scale_2 keys for layer 0") diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py index 124a8205..76e9bc00 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py @@ -61,6 +61,7 @@ def cutlass_grouped_nvfp4_gemm( slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs slot_token=None, # (num_slots,) int64 — per-slot token indices (default: arange) alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale) + per_expert_alpha=None, # (E_per_rank,) float32 — per-expert alpha overrides scalar alpha ): """Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4. @@ -71,6 +72,11 @@ def cutlass_grouped_nvfp4_gemm( For L1: x_fp4 has num_tokens rows, slot_token maps slots→rows. For L2: x_fp4 has num_slots rows, slot_token is just arange(num_slots). + If per_expert_alpha is provided, each expert uses its own alpha value + (activation_global_scale * weight_global_scale[expert]) instead of the + scalar alpha. This preserves full float32 precision — no lossy float8 + folding of weight global scales. 
+ Returns: slot_out: (num_slots, N) bfloat16 — per-slot GEMM results slot_token: (num_slots,) int64 — token index for each slot @@ -100,7 +106,7 @@ def cutlass_grouped_nvfp4_gemm( if MEGA_MOE_DEBUG: print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} " - f"experts={num_experts}") + f"experts={num_experts} per_expert_alpha={'yes' if per_expert_alpha is not None else 'no'}") slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device) @@ -116,9 +122,12 @@ def cutlass_grouped_nvfp4_gemm( expert_w_sf = weight_sf[e] M_expert = e_idx.shape[0] + # Per-expert alpha: activation_gs * weight_gs (float32, no precision loss) + expert_alpha = float(per_expert_alpha[e]) if per_expert_alpha is not None else alpha + if MEGA_MOE_DEBUG and e < 3 and M_expert > 0: print(f"[GEMM-IN] expert={e} M={M_expert} N={N} K={K} " - f"w shape={expert_w.shape}") + f"w shape={expert_w.shape} alpha={expert_alpha:.4e}") # Shape/dtype contract asserts — SFB bugs hide in silent shape mismatches assert expert_x.shape == (M_expert, K // 2), f"expert_x shape {expert_x.shape} != ({M_expert}, {K // 2})" @@ -132,14 +141,14 @@ def cutlass_grouped_nvfp4_gemm( expert_x, expert_x_sf, expert_w, expert_w_sf, M_expert, N, K, - alpha=alpha, + alpha=expert_alpha, ) if MEGA_MOE_DEBUG: if torch.isnan(expert_out).any() or torch.isinf(expert_out).any(): raise RuntimeError( f"expert {e} of {num_experts}: GEMM emitted NaN/Inf. 
" - f"M={M_expert} N={N} K={K}") + f"M={M_expert} N={N} K={K} alpha={expert_alpha:.4e}") slot_out[e_idx] = expert_out diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 72576c24..1daadfcf 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -56,8 +56,6 @@ def unpack_ue4m3_u32(x_u32): return out # CUTLASS native NVFP4 block-scaled GEMM (SM100 Blackwell) -# Primary path: uses CUTLASS MainloopSm100TmaUmmaWarpSpecializedBlockScaled -# which invokes mxf8f6f4.block_scale tensor core instructions directly. MEGA_MOE_USE_CUTLASS = int(os.environ.get("MEGA_MOE_USE_CUTLASS", "1")) try: @@ -97,13 +95,25 @@ def nvfp4_mega_moe_l1( l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs slot_token, # (num_slots,) int64 — token index per slot + l1_global_sf, # (E_per_rank, 2) or (E_per_rank,) float32 — weight global scales alpha=1.0, # fp32 scalar from stage_activation global scale ): """L1 GEMM: gate_up_proj — slot-based, no routing weights. - Takes pre-built slot mapping (slot_expert_ids, slot_token) from the outer - routing logic. Returns (slot_out, slot_token) where each slot is one - (token, topk) pair. + Global scale is NOT folded into block scales. Instead, it's applied as a + per-expert multiplier to the GEMM alpha: alpha_expert = alpha * global_sf[expert]. + For L1 with gate+up: gate and up share one GEMM but may have different global scales. + Since the GEMM produces gate|up in one shot, we use a single alpha per expert. + Post-GEMM, we apply the gate/up ratio correction if they differ. + + Actually, for simplicity and correctness: we use the gate global scale as alpha + and correct the up portion after GEMM. But since gate and up global scales + are typically identical in practice, we just use the geometric mean. 
+ + CLEANER APPROACH: use per-expert alpha directly in the grouped GEMM. + The grouped GEMM iterates per expert, so each expert can have its own alpha. + For L1 with separate gate/up global scales, we use the gate scale as alpha + and then apply an up/gate correction factor to the up portion. """ K_half = x_fp4.shape[1] K = K_half * 2 @@ -116,13 +126,36 @@ def nvfp4_mega_moe_l1( w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales assert w_sf_fp8.dtype == torch.float8_e4m3fn, f"l1_scales after unpack dtype={w_sf_fp8.dtype}" + # Compute per-expert alpha: activation_gs * weight_gs + # For L1 with (E, 2) gate/up global scales, alpha uses gate_gs; up half is rescaled + if l1_global_sf.dim() == 2 and l1_global_sf.shape[1] == 2: + # gate_gs and up_gs per expert — use gate_gs for the GEMM alpha, + # then correct the up half post-GEMM + l1_gate_gs = l1_global_sf[:, 0] # (E,) float32 + l1_up_gs = l1_global_sf[:, 1] # (E,) float32 + per_expert_alpha = alpha * l1_gate_gs # (E,) float32 + up_correction = l1_up_gs / l1_gate_gs # (E,) float32 — ratio to apply to up half + else: + per_expert_alpha = alpha * l1_global_sf # (E,) float32 + up_correction = None + slot_out, slot_token = cutlass_grouped_nvfp4_gemm( x_fp4, x_sf_fp8, l1_weights, w_sf_fp8, - slot_expert_ids, # 1D per-slot expert IDs - slot_token, # 1D per-slot token indices - alpha=alpha, + slot_expert_ids, + slot_token, + per_expert_alpha=per_expert_alpha, ) + + # Apply up correction whenever separate gate/up global scales were given + if up_correction is not None: + gate_N = N // 2 + # For each slot, apply the correction to the up half + # slot_out is (num_slots, N) — up half is [:, gate_N:] + # Correction factor is per-expert: up_correction[slot_expert_ids] + correction = up_correction[slot_expert_ids].unsqueeze(1) # (num_slots, 1) + slot_out[:, gate_N:] = slot_out[:, gate_N:] * correction.to(slot_out.dtype) + print(f"[L1-GEMM-OUT] slots={slot_out.shape[0]} N={N} amax={slot_out.abs().max().item():.4e}
mean={slot_out.float().mean().item():.4e}") return slot_out, slot_token @@ -132,13 +165,15 @@ def nvfp4_mega_moe_l2( x_sf, # (num_slots, sf_k_groups) float8_e4m3fn l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major - slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs (from L1 routing) - slot_token, # (num_slots,) int64 — token index per slot (from L1) + slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs + slot_token, # (num_slots,) int64 — token index per slot + l2_global_sf, # (E_per_rank,) float32 — weight global scales alpha=1.0, # fp32 scalar from stage_activation global scale ): """L2 GEMM: down_proj — slot-based, no routing weights. - Reuses the same slot mapping from L1 (same slot_token and slot_expert_ids). + Per-expert alpha = activation_global_scale * weight_global_scale[expert]. + This preserves full float32 precision — no lossy float8 folding. """ K_half = x_fp4.shape[1] K = K_half * 2 @@ -151,11 +186,14 @@ def nvfp4_mega_moe_l2( w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales assert w_sf_fp8.dtype == torch.float8_e4m3fn, f"l2_scales after unpack dtype={w_sf_fp8.dtype}" + # Per-expert alpha: activation_gs * weight_gs + per_expert_alpha = alpha * l2_global_sf # (E,) float32 + slot_out, _ = cutlass_grouped_nvfp4_gemm( x_fp4, x_sf_fp8, l2_weights, w_sf_fp8, - slot_expert_ids, # 1D per-slot expert IDs — GEMM handles directly - alpha=alpha, + slot_expert_ids, + per_expert_alpha=per_expert_alpha, ) return slot_out # (num_slots, HIDDEN) bfloat16 @@ -236,8 +274,8 @@ def stage_activation(x_bf16): def nvfp4_mega_moe_full( y, # output tensor (num_tokens, HIDDEN) bfloat16 - transformed_l1_weights, # (l1_w, l1_sf) tuple from finalize_weights - transformed_l2_weights, # (l2_w, l2_sf) tuple from finalize_weights + transformed_l1_weights, # (l1_w, l1_sf, l1_global_sf) from finalize_weights + 
transformed_l2_weights, # (l2_w, l2_sf, l2_global_sf) from finalize_weights symm_buffer, # SymmBuffer from get_symm_buffer activation_clamp=None, # optional clamp value (unused in NVFP4) fast_math=False, # fast math flag (unused in NVFP4) @@ -246,10 +284,10 @@ def nvfp4_mega_moe_full( Slot-based pipeline (routing weights applied ONCE at final scatter): 1. Read staged activation from symm_buffer - 2. L1 GEMM → slot output (num_slots, 2*INTER) — NO routing weights + 2. L1 GEMM → slot output (num_slots, 2*INTER) — per-expert alpha 3. SiLU + Mul PER SLOT (nonlinearity before combining expert paths) 4. Quantize activated slots → FP4 - 5. L2 GEMM → slot output (num_slots, HIDDEN) — NO routing weights + 5. L2 GEMM → slot output (num_slots, HIDDEN) — per-expert alpha 6. Final scatter: y.index_add_(0, slot_token, slot_weight * l2_slots) Single routing weight application. """ @@ -264,9 +302,9 @@ def nvfp4_mega_moe_full( y.zero_() return - # Unpack transformed weights - l1_w, l1_sf = transformed_l1_weights - l2_w, l2_sf = transformed_l2_weights + # Unpack transformed weights (now includes global_sf) + l1_w, l1_sf, l1_global_sf = transformed_l1_weights + l2_w, l2_sf, l2_global_sf = transformed_l2_weights # Expert sanity check — are experts actually distinct? 
if not getattr(nvfp4_mega_moe_full, '_expert_sanity', False): @@ -276,6 +314,8 @@ def nvfp4_mega_moe_full( sf_sample = l1_sf[e].to(torch.float32)[:4, :4] print(f"[EXPERT-SANITY e={e}] w_bytes[:8,:8]={w_sample.flatten().tolist()[:16]}") print(f"[EXPERT-SANITY e={e}] sf[:4,:4]={sf_sample.flatten().tolist()[:8]}") + print(f"[EXPERT-SANITY e={e}] l1_global_sf={l1_global_sf[e].tolist()}") + print(f"[EXPERT-SANITY e={e}] l2_global_sf={l2_global_sf[e].tolist()}") # Step 1: Read staged activation from symm_buffer x_fp4 = symm_buffer.x[:num_tokens] @@ -287,7 +327,8 @@ def nvfp4_mega_moe_full( _x_sf_f32 = x_sf.to(torch.float32) _igs = l1_global_scale if isinstance(l1_global_scale, float) else l1_global_scale.item() if hasattr(l1_global_scale, 'item') else float(l1_global_scale) if MEGA_MOE_DEBUG: - print(f"[ALPHA L1] alpha={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}]") + print(f"[ALPHA L1] activation_gs={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}]") + print(f"[ALPHA L1] l1_global_sf range [{l1_global_sf.min().item():.4e}, {l1_global_sf.max().item():.4e}]") # Convert global expert IDs to local expert IDs num_experts_per_rank = l1_w.shape[0] @@ -316,10 +357,10 @@ def nvfp4_mega_moe_full( y.zero_() return - # Ensure alpha is a plain Python float (C extension can't handle torch scalars) + # Ensure alpha is a plain Python float for the base activation global scale l1_alpha = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale - # Shape consistency asserts — catch mismatched slot mappings early + # Shape consistency asserts assert slot_expert_local.ndim == 1 assert slot_token.ndim == 1 assert slot_weight.ndim == 1 @@ -327,21 +368,11 @@ def nvfp4_mega_moe_full( assert slot_token.numel() == num_slots assert slot_weight.numel() == num_slots - # SFB weight scales are remapped per-expert inside CUTLASS on each call. 
- # ───────────────────────────────────────────────────────────────────── - # NO PREPACK CACHE — see README for rationale. - # DO NOT add a prepack cache. Previous attempts caused: - # - OOM: ~1.75 GiB per prepacked tensor × 61 layers = 214 GiB - # - Peak memory 2× during torch.stack before eviction - # - CUDA graph use-after-free on evicted entries - # - M_for_layout=128 assumption (unverified M-independence) - # The SFB remap is a small scatter kernel (~µs) — not the bottleneck. - # ───────────────────────────────────────────────────────────────────── - - # Step 2: L1 GEMM — slot-based, no routing weights + # Step 2: L1 GEMM — slot-based, per-expert alpha l1_slots, _ = nvfp4_mega_moe_l1( x_fp4, x_sf, l1_w, l1_sf, slot_expert_local, slot_token, + l1_global_sf=l1_global_sf, alpha=l1_alpha, ) # (num_slots, 2*INTER) bfloat16 @@ -374,12 +405,14 @@ def nvfp4_mega_moe_full( if MEGA_MOE_DEBUG: _l1sf_f32 = l1_sf_out.to(torch.float32) _l2gs = l2_global_scale if isinstance(l2_global_scale, float) else l2_global_scale.item() - print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}]") + print(f"[ALPHA L2] activation_gs={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}]") + print(f"[ALPHA L2] l2_global_sf range [{l2_global_sf.min().item():.4e}, {l2_global_sf.max().item():.4e}]") - # Step 5: L2 GEMM — slot-based, no routing weights + # Step 5: L2 GEMM — slot-based, per-expert alpha l2_slots = nvfp4_mega_moe_l2( l1_fp4, l1_sf_out, l2_w, l2_sf, slot_expert_local, slot_token, + l2_global_sf=l2_global_sf, alpha=l2_alpha, ) # (num_slots, HIDDEN) bfloat16 diff --git a/src/nvfp4_megamoe_kernel/weight_transform.py b/src/nvfp4_megamoe_kernel/weight_transform.py index c56345df..78aa561e 100644 --- a/src/nvfp4_megamoe_kernel/weight_transform.py +++ b/src/nvfp4_megamoe_kernel/weight_transform.py @@ -5,12 +5,12 @@ Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float3 into the 
format expected by the CUTLASS block-scaled GEMM kernel: - Packed FP4 weights (int8, K-major) - UE4M3 block scales (float8_e4m3fn, row-major — CUTLASS SF remap handles interleaving) +- float32 global scales (NOT folded into block scales — passed separately for per-expert alpha) -Weight scales are returned as float8_e4m3fn (NOT packed uint32). The CUTLASS GEMM -consumes float8 scales directly; only activation scales from the staging kernel come -as uint32 and need unpack_ue4m3_u32. - -This replaces deep_gemma.mega.transform_nvfp4_weights_for_mega_moe. +Previous versions folded weight_scale_2 into block scales via float8 round-trip, which caused +25% relative error (product of ~56-448 block_sf × ~4.65e-05 global_sf lands in the low-precision +zone of float8_e4m3fn where step size is 25%). The global scale is now applied as a per-expert +multiplier to the GEMM alpha, preserving full float32 precision. Call signature matches the nightly vLLM deepseek_v4.py finalize_weights: transform_nvfp4_weights_for_mega_moe( @@ -24,134 +24,80 @@ Call signature matches the nightly vLLM deepseek_v4.py finalize_weights: import torch -def _fold_global_scale( - weight_scale: torch.Tensor, # (E, N, K//16) float8_e4m3fn - weight_scale_2: torch.Tensor, # (E,) or (E, 2) or scalar float32 -) -> torch.Tensor: - """Fold global scale into block scales: UE4M3 * FP32 → float32. - - For fused projections (w13 = gate+up), weight_scale_2 is (E, 2): - scale_2[e, 0] applies to gate_proj rows, scale_2[e, 1] applies to up_proj rows. - N is split: gate = weight_scale[:, :N//2, :], up = weight_scale[:, N//2:, :] - For single projections (w2), weight_scale_2 is (E,) or scalar. 
- """ - sf_f32 = weight_scale.to(torch.float32) - gs = weight_scale_2.to(torch.float32) - - if gs.numel() == 1: - sf_f32 = sf_f32 * gs - elif gs.dim() == 2 and gs.shape[1] == 2: - # Fused projection: (E, 2) — gate and up have separate global scales - # weight_scale is (E, N, K//16), N = gate_N + up_N - gate_N = sf_f32.shape[1] // 2 - gs_gate = gs[:, 0].unsqueeze(-1) # (E, 1) - gs_up = gs[:, 1].unsqueeze(-1) # (E, 1) - sf_f32[:, :gate_N, :] = sf_f32[:, :gate_N, :] * gs_gate.unsqueeze(-1) - sf_f32[:, gate_N:, :] = sf_f32[:, gate_N:, :] * gs_up.unsqueeze(-1) - else: - # Per-expert global scale — broadcast multiply - while gs.dim() < sf_f32.dim(): - gs = gs.unsqueeze(-1) - sf_f32 = sf_f32 * gs.expand_as(sf_f32) - - return sf_f32 - - -def _pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor: - """Pack 4 UE4M3 (float8_e4m3fn) values into one uint32.""" - sf_u8 = sf.view(torch.uint8) - assert sf_u8.shape[-1] % 4 == 0, f"Last dim {sf_u8.shape[-1]} not divisible by 4" - packed = (sf_u8[..., 0::4].to(torch.int32) | - (sf_u8[..., 1::4].to(torch.int32) << 8) | - (sf_u8[..., 2::4].to(torch.int32) << 16) | - (sf_u8[..., 3::4].to(torch.int32) << 24)) - return packed.contiguous() - - def transform_nvfp4_weights_for_mega_moe( l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale) l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale) l1_weight_scale_2: torch.Tensor = None, # float32 global scale for L1 l2_weight_scale_2: torch.Tensor = None, # float32 global scale for L2 -) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: +) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], + tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Transform NVFP4 weights for the CUTLASS block-scaled GEMM. - - Matches the call signature from nightly vLLM deepseek_v4.py finalize_weights. - + + NO LONGER FOLDS GLOBAL SCALES INTO BLOCK SCALES. 
+ Folding block_sf (float8) × global_sf (float32) → float8 loses ~25% precision + because the product lands in the low-precision zone of float8_e4m3fn. + Instead, global scales are returned separately and applied as per-expert GEMM alpha. + Args: l1_tuple: (w13_weight, w13_weight_scale) — gate_up proj l2_tuple: (w2_weight, w2_weight_scale) — down proj l1_weight_scale_2: global scale for L1 (float32) + Shape (E, 2) for gate+up, or (E,) per-expert, or scalar l2_weight_scale_2: global scale for L2 (float32) - + Shape (E,) per-expert, or scalar + Returns: - ((l1_weight, l1_sf_packed), (l2_weight, l2_sf_packed)) + ((l1_weight, l1_sf, l1_global_sf), (l2_weight, l2_sf, l2_global_sf)) + where l1_global_sf is float32, unmodified: (E, 2) = [gate_gs, up_gs] or (E,), + and l2_global_sf is the float32 per-expert global scale for L2. + The caller must apply global_sf as a per-expert multiplier to the GEMM alpha. """ l1_weight, l1_weight_scale = l1_tuple l2_weight, l2_weight_scale = l2_tuple - # DEBUG: check raw scales before folding - l1_sf_f32_raw = l1_weight_scale.to(torch.float32) - l1_gs_raw = l1_weight_scale_2.to(torch.float32) if l1_weight_scale_2 is not None else None - if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug', False): - transform_nvfp4_weights_for_mega_moe._sf_debug = True - print(f"[SF-DEBUG] raw l1_sf dtype={l1_weight_scale.dtype} range=[{l1_sf_f32_raw.min().item():.4e}, {l1_sf_f32_raw.max().item():.4e}] " - f"unique_raw={torch.unique(l1_weight_scale.view(torch.uint8)).numel()}") - if l1_gs_raw is not None: - print(f"[SF-DEBUG] l1_gs dtype={l1_weight_scale_2.dtype} shape={tuple(l1_weight_scale_2.shape)} " - f"range=[{l1_gs_raw.min().item():.4e}, {l1_gs_raw.max().item():.4e}] " - f"unique_gs={torch.unique(l1_gs_raw).numel()}") - if l1_gs_raw.dim() == 2 and l1_gs_raw.shape[1] == 2: - print(f"[SF-DEBUG] gate gs unique={torch.unique(l1_gs_raw[:, 0]).numel()} " - f"up gs unique={torch.unique(l1_gs_raw[:, 1]).numel()}") - - # DEBUG: check L2 scales - l2_sf_f32_raw =
l2_weight_scale.to(torch.float32) - l2_gs_raw = l2_weight_scale_2.to(torch.float32) if l2_weight_scale_2 is not None else None - if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug_l2', False): - transform_nvfp4_weights_for_mega_moe._sf_debug_l2 = True - print(f"[SF-DEBUG-L2] raw l2_sf dtype={l2_weight_scale.dtype} range=[{l2_sf_f32_raw.min().item():.4e}, {l2_sf_f32_raw.max().item():.4e}] " - f"unique_raw={torch.unique(l2_weight_scale.view(torch.uint8)).numel()}") - if l2_gs_raw is not None: - print(f"[SF-DEBUG-L2] l2_gs dtype={l2_weight_scale_2.dtype} shape={tuple(l2_weight_scale_2.shape)} " - f"range=[{l2_gs_raw.min().item():.4e}, {l2_gs_raw.max().item():.4e}] " - f"unique_gs={torch.unique(l2_gs_raw).numel()}") - - # Post-fold diagnostics — one-time - if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug_fold', False): - transform_nvfp4_weights_for_mega_moe._sf_debug_fold = True - l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2) if l1_weight_scale_2 is not None else l1_weight_scale.to(torch.float32) - l1_sf_out_check = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) - l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2) if l2_weight_scale_2 is not None else l2_weight_scale.to(torch.float32) - l2_sf_out_check = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) - print(f"[SF-FOLD] l1 pre-fold unique_u8={torch.unique(l1_weight_scale.view(torch.uint8)).numel()} " - f"post-fold unique_u8={torch.unique(l1_sf_out_check.view(torch.uint8)).numel()} " - f"range=[{l1_sf_folded.min().item():.4e}, {l1_sf_folded.max().item():.4e}]") - print(f"[SF-FOLD] l2 pre-fold unique_u8={torch.unique(l2_weight_scale.view(torch.uint8)).numel()} " - f"post-fold unique_u8={torch.unique(l2_sf_out_check.view(torch.uint8)).numel()} " - f"range=[{l2_sf_folded.min().item():.4e}, {l2_sf_folded.max().item():.4e}]") - - # Fold global scales into block scales - # The logical_widths branch was wrong: it treated gs as per-projection 
- # scalars and only used experts 0 and 1's scales for ALL experts. - # The else branch correctly broadcasts each expert's own global scale. + # Extract global scales as per-expert float32 vectors + # L1: gate/up have separate global scales — store both + # The caller (nvfp4_mega_moe_full) will apply the right one per-expert if l1_weight_scale_2 is not None: - l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2) + l1_gs = l1_weight_scale_2.to(torch.float32) + if l1_gs.dim() == 2 and l1_gs.shape[1] == 2: + # (E, 2) — gate_gs and up_gs separate + # For L1 alpha, use the geometric mean (close enough since gate and up + # global scales are typically similar). Actually, we need BOTH because + # the GEMM produces gate and up in one shot. + # Better: just store (E, 2) and let the caller apply post-GEMM scaling. + l1_global_sf = l1_gs # (E, 2) float32 + else: + l1_global_sf = l1_gs # (E,) float32 else: - l1_sf_folded = l1_weight_scale.to(torch.float32) + l1_global_sf = torch.ones(l1_weight.shape[0], dtype=torch.float32, device=l1_weight.device) if l2_weight_scale_2 is not None: - l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2) + l2_gs = l2_weight_scale_2.to(torch.float32) + l2_global_sf = l2_gs # (E,) or scalar → broadcast to (E,) + if l2_global_sf.dim() == 0: + l2_global_sf = l2_global_sf.expand(l2_weight.shape[0]) else: - l2_sf_folded = l2_weight_scale.to(torch.float32) + l2_global_sf = torch.ones(l2_weight.shape[0], dtype=torch.float32, device=l2_weight.device) - # Clamp and convert back to UE4M3 - l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous() - l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous() + # Debug: one-time diagnostic + if not getattr(transform_nvfp4_weights_for_mega_moe, '_diag', False): + transform_nvfp4_weights_for_mega_moe._diag = True + print(f"[WT-XFORM] L1 block_sf range=[{l1_weight_scale.float().min():.4e}, " + f"{l1_weight_scale.float().max():.4e}] 
unique={torch.unique(l1_weight_scale.view(torch.uint8)).numel()}") + print(f"[WT-XFORM] L1 global_sf: shape={tuple(l1_global_sf.shape)} " + f"range=[{l1_global_sf.min():.4e}, {l1_global_sf.max():.4e}]") + print(f"[WT-XFORM] L2 block_sf range=[{l2_weight_scale.float().min():.4e}, " + f"{l2_weight_scale.float().max():.4e}] unique={torch.unique(l2_weight_scale.view(torch.uint8)).numel()}") + print(f"[WT-XFORM] L2 global_sf: shape={tuple(l2_global_sf.shape)} " + f"range=[{l2_global_sf.min():.4e}, {l2_global_sf.max():.4e}]") + + # Block scales stay as original float8 — NO FOLDING + l1_sf_out = l1_weight_scale.contiguous() + l2_sf_out = l2_weight_scale.contiguous() # CUTLASS B is declared ColumnMajor — it expects (K, N) in memory. # Checkpoint weights are (N, K_half) row-major, so we transpose to (K_half, N) - # which is column-major (N, K_half). This is a one-time cost at load time. l1_weight_out = l1_weight.transpose(-2, -1).contiguous() l2_weight_out = l2_weight.transpose(-2, -1).contiguous() @@ -159,4 +105,4 @@ def transform_nvfp4_weights_for_mega_moe( l1_sf_out = l1_sf_out.transpose(-2, -1).contiguous() l2_sf_out = l2_sf_out.transpose(-2, -1).contiguous() - return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out) + return (l1_weight_out, l1_sf_out, l1_global_sf), (l2_weight_out, l2_sf_out, l2_global_sf) diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 93692955..74a39f86 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -346,8 +346,8 @@ class DeepseekV4MegaMoEExperts(nn.Module): ) set_weight_attrs(self.w2_input_scale, weight_attrs) - self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None - self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None + self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None + self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None # Register in the static forward 
context so the custom-op wrapper # can look up this module by name from within a torch.compile graph. @@ -437,13 +437,15 @@ class DeepseekV4MegaMoEExperts(nn.Module): from nvfp4_megamoe_kernel import transform_nvfp4_weights_for_mega_moe # === Native NVFP4 path === - # The DeepGEMM nvfp4 mega_moe kernel consumes NVFP4 directly: - # - E2M1 packed uint8 (same as checkpoint) + # The CUTLASS nvfp4 mega_moe kernel consumes NVFP4 directly: + # - E2M1 packed int8 (same as checkpoint) # - UE4M3 block scales (float8_e4m3fn), group_size=16 - # - float32 global scale folded into block scales - # No conversion to MXFP4. Experts stay NVFP4. + # - float32 global scales returned SEPARATELY (NOT folded into float8) + # Previous versions folded global scales into block scales via float8 + # round-trip, which caused ~25% precision loss. Now, global scales + # are applied as per-expert GEMM alpha in float32 (exact). - # Fold global scales into block scales and transform for the kernel + # Transform weights — returns (w, sf, global_sf) tuples self._transformed_l1_weights, self._transformed_l2_weights = ( transform_nvfp4_weights_for_mega_moe( (self.w13_weight.data.contiguous(),