diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 50d55d5c..08eb88b4 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -161,12 +161,11 @@ def nvfp4_mega_moe_l1( def nvfp4_mega_moe_l2( - x_fp4, # (num_slots, INTER//2) int8 packed E2M1 + x_fp4, # (num_slots, INTER//2) int8 packed E2M1 — already slot-major x_sf, # (num_slots, sf_k_groups) float8_e4m3fn l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs - slot_token, # (num_slots,) int64 — token index per slot l2_global_sf, # (E_per_rank,) float32 — weight global scales alpha=1.0, # fp32 scalar from stage_activation global scale ): @@ -502,7 +501,7 @@ def nvfp4_mega_moe_full( # Step 5: L2 GEMM — slot-based, per-expert alpha l2_slots = nvfp4_mega_moe_l2( l1_fp4, l1_sf_out, l2_w, l2_sf, - slot_expert_local, slot_token, + slot_expert_local, l2_global_sf=l2_global_sf, alpha=l2_alpha, ) # (num_slots, HIDDEN) bfloat16