From b7c7e9fb5041fa4f379dc03a2802e014d243e110 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 10:11:40 +0000 Subject: [PATCH] refactor: clean up slot_token handling in cutlass_grouped_nvfp4_gemm - Split provided_slot_token vs slot_token_out (returned to caller) - No gather when slot_token=None (L2 path), no unnecessary alloc - .contiguous() on gathered tensors for CUTLASS alignment - Return slot_token_out consistently --- .../cutlass_nvfp4_gemm/kernel.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py index 663bb222..b7aa774c 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py @@ -87,22 +87,21 @@ def cutlass_grouped_nvfp4_gemm( return slot_out, slot_token_out # Use provided slot_token or default to identity mapping - if slot_token is None: - slot_token = torch.arange(num_slots, device=x_fp4.device) + provided_slot_token = slot_token + + if provided_slot_token is None: + slot_token_out = torch.arange(num_slots, device=x_fp4.device) + slot_x = x_fp4 + slot_x_sf = x_sf + else: + slot_token_out = provided_slot_token + slot_x = x_fp4[provided_slot_token].contiguous() + slot_x_sf = x_sf[provided_slot_token].contiguous() if MEGA_MOE_DEBUG: print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} " f"experts={num_experts} sfb_prepacked={sfb_prepacked}") - # Gather input rows by slot_token when provided (L1: tokens→slots). - # L2 doesn't pass slot_token, so no gather needed. - if slot_token is not None: - slot_x = x_fp4[slot_token] - slot_x_sf = x_sf[slot_token] - else: - slot_x = x_fp4 - slot_x_sf = x_sf - slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device) for e in range(num_experts): @@ -137,4 +136,4 @@ def cutlass_grouped_nvfp4_gemm( slot_out[e_idx] = expert_out - return slot_out, slot_token + return slot_out, slot_token_out