fix: use slot_token identity check instead of shape heuristic for gather

Shape-based check (x_fp4.shape[0] != num_slots) silently fails when
num_tokens == num_slots in L1 (topk=1). Now checks if slot_token is
the identity mapping — only gathers when slot ordering differs from
token ordering.
This commit is contained in:
2026-05-15 10:00:41 +00:00
parent ded80be133
commit 3ba41b9322

View File

@@ -94,9 +94,10 @@ def cutlass_grouped_nvfp4_gemm(
print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} "
f"experts={num_experts} sfb_prepacked={sfb_prepacked}")
# Gather input rows by slot_token when x_fp4 has more tokens than slots
# (L1: x_fp4=num_tokens, L2: x_fp4=num_slots)
if x_fp4.shape[0] != num_slots:
# Gather input rows by slot_token — needed when x_fp4 has token rows
# but we need slot rows (L1). When slot_token is the identity (L2),
# x_fp4 already has slot rows and the gather is a no-op.
if slot_token is not None and not torch.equal(slot_token, torch.arange(num_slots, device=x_fp4.device)):
slot_x = x_fp4[slot_token]
slot_x_sf = x_sf[slot_token]
else: