cleanup: remove unused slot_token from nvfp4_mega_moe_l2

The L2 input is already slot-major, so slot_token was accepted but never
passed to the GEMM. Remove the parameter to make that explicit, and update
the call in nvfp4_mega_moe_full to match.
This commit is contained in:
2026-05-15 23:50:39 +00:00
parent 887360281e
commit bb5a1ba4c8

View File

@@ -161,12 +161,11 @@ def nvfp4_mega_moe_l1(
def nvfp4_mega_moe_l2(
x_fp4, # (num_slots, INTER//2) int8 packed E2M1
x_fp4, # (num_slots, INTER//2) int8 packed E2M1 — already slot-major
x_sf, # (num_slots, sf_k_groups) float8_e4m3fn
l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS
l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
slot_token, # (num_slots,) int64 — token index per slot
l2_global_sf, # (E_per_rank,) float32 — weight global scales
alpha=1.0, # fp32 scalar from stage_activation global scale
):
@@ -502,7 +501,7 @@ def nvfp4_mega_moe_full(
# Step 5: L2 GEMM — slot-based, per-expert alpha
l2_slots = nvfp4_mega_moe_l2(
l1_fp4, l1_sf_out, l2_w, l2_sf,
slot_expert_local, slot_token,
slot_expert_local,
l2_global_sf=l2_global_sf,
alpha=l2_alpha,
) # (num_slots, HIDDEN) bfloat16