refactor: clean up slot_token handling in cutlass_grouped_nvfp4_gemm
- Split provided_slot_token vs slot_token_out (returned to caller) - No gather when slot_token=None (L2 path), no unnecessary alloc - .contiguous() on gathered tensors for CUTLASS alignment - Return slot_token_out consistently
This commit is contained in:
@@ -87,22 +87,21 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
return slot_out, slot_token_out
|
||||
|
||||
# Use provided slot_token or default to identity mapping
|
||||
if slot_token is None:
|
||||
slot_token = torch.arange(num_slots, device=x_fp4.device)
|
||||
provided_slot_token = slot_token
|
||||
|
||||
if provided_slot_token is None:
|
||||
slot_token_out = torch.arange(num_slots, device=x_fp4.device)
|
||||
slot_x = x_fp4
|
||||
slot_x_sf = x_sf
|
||||
else:
|
||||
slot_token_out = provided_slot_token
|
||||
slot_x = x_fp4[provided_slot_token].contiguous()
|
||||
slot_x_sf = x_sf[provided_slot_token].contiguous()
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} "
|
||||
f"experts={num_experts} sfb_prepacked={sfb_prepacked}")
|
||||
|
||||
# Gather input rows by slot_token when provided (L1: tokens→slots).
|
||||
# L2 doesn't pass slot_token, so no gather needed.
|
||||
if slot_token is not None:
|
||||
slot_x = x_fp4[slot_token]
|
||||
slot_x_sf = x_sf[slot_token]
|
||||
else:
|
||||
slot_x = x_fp4
|
||||
slot_x_sf = x_sf
|
||||
|
||||
slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device)
|
||||
|
||||
for e in range(num_experts):
|
||||
@@ -137,4 +136,4 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
|
||||
slot_out[e_idx] = expert_out
|
||||
|
||||
return slot_out, slot_token
|
||||
return slot_out, slot_token_out
|
||||
|
||||
Reference in New Issue
Block a user