From 7a1538d0c84d2df7c130cad23ebbe14ca05ca312 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 10:06:07 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20gather=20on=20slot=5Ftoken=20presence,?= =?UTF-8?q?=20add=20shape=20asserts=20L1=E2=86=92L2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove torch.equal heuristic — just gather when slot_token is provided - Add asserts for slot mapping shapes (ndim, numel == num_slots) - Add post-L1 and pre-L2 shape asserts (l1_slots, activated, l1_fp4, l1_sf_out) --- .../cutlass_nvfp4_gemm/kernel.py | 7 +++---- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py index bda4e353..663bb222 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py @@ -94,10 +94,9 @@ def cutlass_grouped_nvfp4_gemm( print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} " f"experts={num_experts} sfb_prepacked={sfb_prepacked}") - # Gather input rows by slot_token — needed when x_fp4 has token rows - # but we need slot rows (L1). When slot_token is the identity (L2), - # x_fp4 already has slot rows and the gather is a no-op. - if slot_token is not None and not torch.equal(slot_token, torch.arange(num_slots, device=x_fp4.device)): + # Gather input rows by slot_token when provided (L1: tokens→slots). + # L2 doesn't pass slot_token, so no gather needed. + if slot_token is not None: slot_x = x_fp4[slot_token] slot_x_sf = x_sf[slot_token] else: diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 08692619..ce1d3f33 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -369,6 +369,14 @@ def nvfp4_mega_moe_full( # Ensure alpha is a plain Python float (C extension can't handle torch scalars) l1_alpha = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale + # Shape consistency asserts — catch mismatched slot mappings early + assert slot_expert_local.ndim == 1 + assert slot_token.ndim == 1 + assert slot_weight.ndim == 1 + assert slot_expert_local.numel() == num_slots + assert slot_token.numel() == num_slots + assert slot_weight.numel() == num_slots + # Prepack SFB weight scales into CUTLASS layout (lazy, once per layer) l1_N = l1_w.shape[2] l1_K = l1_w.shape[1] * 2 @@ -386,6 +394,9 @@ def nvfp4_mega_moe_full( sfb_prepacked=True, ) # (num_slots, 2*INTER) bfloat16 + # Post-L1 shape asserts + assert l1_slots.shape[0] == num_slots + if MEGA_MOE_DEBUG: print(f"[L1-out] nan={torch.isnan(l1_slots).any().item()} " f"abs_max={l1_slots.abs().max().item():.4e}") @@ -402,6 +413,11 @@ def nvfp4_mega_moe_full( # Step 4: Quantize activated slots → FP4 l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated) + + # Pre-L2 shape asserts + assert activated.shape[0] == num_slots + assert l1_fp4.shape[0] == num_slots + assert l1_sf_out.shape[0] == num_slots l2_alpha = float(l2_global_scale) if not isinstance(l2_global_scale, float) else l2_global_scale if MEGA_MOE_DEBUG: