fix: use slot_token identity check instead of shape heuristic for gather
Shape-based check (x_fp4.shape[0] != num_slots) silently fails when num_tokens == num_slots in L1 (topk=1). Now checks if slot_token is the identity mapping — only gathers when slot ordering differs from token ordering.
This commit is contained in:
@@ -94,9 +94,10 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} "
|
||||
f"experts={num_experts} sfb_prepacked={sfb_prepacked}")
|
||||
|
||||
# Gather input rows by slot_token when x_fp4 has more tokens than slots
|
||||
# (L1: x_fp4=num_tokens, L2: x_fp4=num_slots)
|
||||
if x_fp4.shape[0] != num_slots:
|
||||
# Gather input rows by slot_token — needed when x_fp4 has token rows
|
||||
# but we need slot rows (L1). When slot_token is the identity (L2),
|
||||
# x_fp4 already has slot rows and the gather is a no-op.
|
||||
if slot_token is not None and not torch.equal(slot_token, torch.arange(num_slots, device=x_fp4.device)):
|
||||
slot_x = x_fp4[slot_token]
|
||||
slot_x_sf = x_sf[slot_token]
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user