From bb5a1ba4c844f40dd7d6f1a0f40ce8a57af5a4c1 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 23:50:39 +0000 Subject: [PATCH] cleanup: remove unused slot_token from nvfp4_moe_l2 L2 input is already slot-major, so slot_token was accepted but never passed to the GEMM. Made it explicit by removing the parameter. --- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 50d55d5c..08eb88b4 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -161,12 +161,11 @@ def nvfp4_mega_moe_l1( def nvfp4_mega_moe_l2( - x_fp4, # (num_slots, INTER//2) int8 packed E2M1 + x_fp4, # (num_slots, INTER//2) int8 packed E2M1 — already slot-major x_sf, # (num_slots, sf_k_groups) float8_e4m3fn l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs - slot_token, # (num_slots,) int64 — token index per slot l2_global_sf, # (E_per_rank,) float32 — weight global scales alpha=1.0, # fp32 scalar from stage_activation global scale ): @@ -502,7 +501,7 @@ def nvfp4_mega_moe_full( # Step 5: L2 GEMM — slot-based, per-expert alpha l2_slots = nvfp4_mega_moe_l2( l1_fp4, l1_sf_out, l2_w, l2_sf, - slot_expert_local, slot_token, + slot_expert_local, l2_global_sf=l2_global_sf, alpha=l2_alpha, ) # (num_slots, HIDDEN) bfloat16