cleanup: remove unused slot_token from nvfp4_moe_l2
L2 input is already slot-major, so slot_token was accepted but never passed to the GEMM. Made it explicit by removing the parameter.
This commit is contained in:
@@ -161,12 +161,11 @@ def nvfp4_mega_moe_l1(
|
||||
|
||||
|
||||
def nvfp4_mega_moe_l2(
|
||||
x_fp4, # (num_slots, INTER//2) int8 packed E2M1
|
||||
x_fp4, # (num_slots, INTER//2) int8 packed E2M1 — already slot-major
|
||||
x_sf, # (num_slots, sf_k_groups) float8_e4m3fn
|
||||
l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS
|
||||
l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major
|
||||
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
|
||||
slot_token, # (num_slots,) int64 — token index per slot
|
||||
l2_global_sf, # (E_per_rank,) float32 — weight global scales
|
||||
alpha=1.0, # fp32 scalar from stage_activation global scale
|
||||
):
|
||||
@@ -502,7 +501,7 @@ def nvfp4_mega_moe_full(
|
||||
# Step 5: L2 GEMM — slot-based, per-expert alpha
|
||||
l2_slots = nvfp4_mega_moe_l2(
|
||||
l1_fp4, l1_sf_out, l2_w, l2_sf,
|
||||
slot_expert_local, slot_token,
|
||||
slot_expert_local,
|
||||
l2_global_sf=l2_global_sf,
|
||||
alpha=l2_alpha,
|
||||
) # (num_slots, HIDDEN) bfloat16
|
||||
|
||||
Reference in New Issue
Block a user