diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py index a0665816..377f547a 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py @@ -54,24 +54,28 @@ def prepack_sfb(SFB, M, N, K): def cutlass_grouped_nvfp4_gemm( - x_fp4, # (num_tokens, K_half) int8 packed E2M1 - x_sf, # (num_tokens, sf_k) float8_e4m3fn block scales + x_fp4, # (num_slots_or_tokens, K_half) int8 packed E2M1 + x_sf, # (num_slots_or_tokens, sf_k) float8_e4m3fn block scales weights, # (E_per_rank, K_half, N) int8 packed E2M1, column-major for CUTLASS weight_sf, # (E_per_rank, sf_k, N) float8_e4m3fn, column-major — or prepacked (E_per_rank, sfb_size) if sfb_prepacked=True - topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs + slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs + slot_token=None, # (num_slots,) int64 — per-slot token indices (default: arange) alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale) sfb_prepacked=False, # True if weight_sf is already prepacked into CUTLASS layout ): """Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4. - Returns slot-based output: one row per (token, topk) slot routed to a local - expert. No routing weights applied — caller handles that at the final scatter. + Takes 1D per-slot expert IDs and token indices (pre-built by caller). + Returns slot-based output: one row per (token, topk) slot. + + For L1: x_fp4 has num_tokens rows, slot_token maps slots→rows. + For L2: x_fp4 has num_slots rows, slot_token is just arange(num_slots). Returns: slot_out: (num_slots, N) bfloat16 — per-slot GEMM results slot_token: (num_slots,) int64 — token index for each slot """ - num_slots = x_fp4.shape[0] + num_slots = slot_expert_ids.shape[0] K_half = x_fp4.shape[1] K = K_half * 2 N = weights.shape[2] @@ -82,30 +86,27 @@ def cutlass_grouped_nvfp4_gemm( slot_token_out = torch.empty(0, dtype=torch.int64, device=x_fp4.device) return slot_out, slot_token_out - # topk_ids is either: - # 2D (num_tokens, num_topk) from L1 — build slot mapping - # 1D (num_slots,) from L2 — already per-slot expert IDs - if topk_ids.dim() == 2: - num_tokens = topk_ids.shape[0] - local_mask = (topk_ids >= 0) & (topk_ids < num_experts) - slot_token, slot_k = local_mask.nonzero(as_tuple=True) - slot_expert = topk_ids[slot_token, slot_k] - else: - # 1D per-slot expert IDs — slot_token is just arange - slot_expert = topk_ids + # Use provided slot_token or default to identity mapping + if slot_token is None: slot_token = torch.arange(num_slots, device=x_fp4.device) if MEGA_MOE_DEBUG: print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} " f"experts={num_experts} sfb_prepacked={sfb_prepacked}") - slot_x = x_fp4 - slot_x_sf = x_sf + # Gather input rows by slot_token when x_fp4 has more tokens than slots + # (L1: x_fp4=num_tokens, L2: x_fp4=num_slots) + if x_fp4.shape[0] != num_slots: + slot_x = x_fp4[slot_token] + slot_x_sf = x_sf[slot_token] + else: + slot_x = x_fp4 + slot_x_sf = x_sf slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device) for e in range(num_experts): - expert_slots = (slot_expert == e) + expert_slots = (slot_expert_ids == e) if not expert_slots.any(): continue diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index bf021ff1..bbd3a27d 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -126,21 +126,23 @@ def nvfp4_mega_moe_l1( x_sf, # (num_tokens, sf_k_groups) float8_e4m3fn l1_weights, # (E_per_rank, K//2, 2*INTER) int8, column-major for CUTLASS l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major — or prepacked - topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs + slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs + slot_token, # (num_slots,) int64 — token index per slot alpha=1.0, # fp32 scalar from stage_activation global scale sfb_prepacked=False, # True if l1_scales is prepacked CUTLASS layout ): """L1 GEMM: gate_up_proj — slot-based, no routing weights. - Returns (slot_out, slot_token) where each slot is one (token, topk) pair. - Caller applies SiLU+Mul per-slot, then L2, then final scatter with weights. + Takes pre-built slot mapping (slot_expert_ids, slot_token) from the outer + routing logic. Returns (slot_out, slot_token) where each slot is one + (token, topk) pair. """ K_half = x_fp4.shape[1] K = K_half * 2 N = l1_weights.shape[2] # 2 * INTERMEDIATE = 6144 if MEGA_MOE_DEBUG: - print(f"[nvfp4_moe_l1] tokens={x_fp4.shape[0]} K={K} N={N} native=1") + print(f"[nvfp4_moe_l1] tokens={x_fp4.shape[0]} K={K} N={N} slots={slot_expert_ids.shape[0]}") x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf if not sfb_prepacked: @@ -151,7 +153,8 @@ def nvfp4_mega_moe_l1( slot_out, slot_token = cutlass_grouped_nvfp4_gemm( x_fp4, x_sf_fp8, l1_weights, w_sf_fp8, - topk_ids, + slot_expert_ids, # 1D per-slot expert IDs + slot_token, # 1D per-slot token indices alpha=alpha, sfb_prepacked=sfb_prepacked, ) @@ -369,7 +372,7 @@ def nvfp4_mega_moe_full( # Step 2: L1 GEMM — slot-based, no routing weights, prepacked SFB l1_slots, _ = nvfp4_mega_moe_l1( x_fp4, x_sf, l1_w, l1_sf_prepacked, - topk_ids_local, + slot_expert_local, slot_token, alpha=l1_alpha, sfb_prepacked=True, ) # (num_slots, 2*INTER) bfloat16