diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py index 377f547a..bda4e353 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py @@ -94,9 +94,10 @@ def cutlass_grouped_nvfp4_gemm( print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} " f"experts={num_experts} sfb_prepacked={sfb_prepacked}") - # Gather input rows by slot_token when x_fp4 has more tokens than slots - # (L1: x_fp4=num_tokens, L2: x_fp4=num_slots) - if x_fp4.shape[0] != num_slots: + # Gather input rows by slot_token — needed when x_fp4 has token rows + # but we need slot rows (L1). When slot_token is the identity (L2), + # x_fp4 already has slot rows and the gather is a no-op. + if slot_token is not None and not torch.equal(slot_token, torch.arange(num_slots, device=x_fp4.device)): slot_x = x_fp4[slot_token] slot_x_sf = x_sf[slot_token] else: