refactor: unify L1/L2 to use 1D slot_expert_ids consistently

Both L1 and L2 now pass pre-built 1D slot_expert_ids and slot_token to
cutlass_grouped_nvfp4_gemm instead of the 2D topk_ids.

The 2D path was broken for expert parallelism — local_mask matched ALL
local experts, producing mismatched slot_token/slot_k lengths that caused
vectorized_gather_kernel index out of bounds.

cutlass_grouped_nvfp4_gemm now:
- Takes 1D slot_expert_ids + optional slot_token
- Gathers x_fp4 by slot_token when needed (L1: tokens→slots)
- Skips gather when x_fp4 already has num_slots rows (L2)
This commit is contained in:
2026-05-15 09:56:46 +00:00
parent 093babadc6
commit ded80be133
2 changed files with 30 additions and 26 deletions

View File

@@ -54,24 +54,28 @@ def prepack_sfb(SFB, M, N, K):
def cutlass_grouped_nvfp4_gemm(
x_fp4, # (num_tokens, K_half) int8 packed E2M1
x_sf, # (num_tokens, sf_k) float8_e4m3fn block scales
x_fp4, # (num_slots_or_tokens, K_half) int8 packed E2M1
x_sf, # (num_slots_or_tokens, sf_k) float8_e4m3fn block scales
weights, # (E_per_rank, K_half, N) int8 packed E2M1, column-major for CUTLASS
weight_sf, # (E_per_rank, sf_k, N) float8_e4m3fn, column-major — or prepacked (E_per_rank, sfb_size) if sfb_prepacked=True
topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
slot_token=None, # (num_slots,) int64 — per-slot token indices (default: arange)
alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale)
sfb_prepacked=False, # True if weight_sf is already prepacked into CUTLASS layout
):
"""Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4.
Returns slot-based output: one row per (token, topk) slot routed to a local
expert. No routing weights applied — caller handles that at the final scatter.
Takes 1D per-slot expert IDs and token indices (pre-built by caller).
Returns slot-based output: one row per (token, topk) slot.
For L1: x_fp4 has num_tokens rows, slot_token maps slots→rows.
For L2: x_fp4 has num_slots rows, slot_token is just arange(num_slots).
Returns:
slot_out: (num_slots, N) bfloat16 — per-slot GEMM results
slot_token: (num_slots,) int64 — token index for each slot
"""
num_slots = x_fp4.shape[0]
num_slots = slot_expert_ids.shape[0]
K_half = x_fp4.shape[1]
K = K_half * 2
N = weights.shape[2]
@@ -82,30 +86,27 @@ def cutlass_grouped_nvfp4_gemm(
slot_token_out = torch.empty(0, dtype=torch.int64, device=x_fp4.device)
return slot_out, slot_token_out
# topk_ids is either:
# 2D (num_tokens, num_topk) from L1 — build slot mapping
# 1D (num_slots,) from L2 — already per-slot expert IDs
if topk_ids.dim() == 2:
num_tokens = topk_ids.shape[0]
local_mask = (topk_ids >= 0) & (topk_ids < num_experts)
slot_token, slot_k = local_mask.nonzero(as_tuple=True)
slot_expert = topk_ids[slot_token, slot_k]
else:
# 1D per-slot expert IDs — slot_token is just arange
slot_expert = topk_ids
# Use provided slot_token or default to identity mapping
if slot_token is None:
slot_token = torch.arange(num_slots, device=x_fp4.device)
if MEGA_MOE_DEBUG:
print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} "
f"experts={num_experts} sfb_prepacked={sfb_prepacked}")
slot_x = x_fp4
slot_x_sf = x_sf
# Gather input rows by slot_token when x_fp4 has more tokens than slots
# (L1: x_fp4=num_tokens, L2: x_fp4=num_slots)
if x_fp4.shape[0] != num_slots:
slot_x = x_fp4[slot_token]
slot_x_sf = x_sf[slot_token]
else:
slot_x = x_fp4
slot_x_sf = x_sf
slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device)
for e in range(num_experts):
expert_slots = (slot_expert == e)
expert_slots = (slot_expert_ids == e)
if not expert_slots.any():
continue

View File

@@ -126,21 +126,23 @@ def nvfp4_mega_moe_l1(
x_sf, # (num_tokens, sf_k_groups) float8_e4m3fn
l1_weights, # (E_per_rank, K//2, 2*INTER) int8, column-major for CUTLASS
l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major — or prepacked
topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
slot_token, # (num_slots,) int64 — token index per slot
alpha=1.0, # fp32 scalar from stage_activation global scale
sfb_prepacked=False, # True if l1_scales is prepacked CUTLASS layout
):
"""L1 GEMM: gate_up_proj — slot-based, no routing weights.
Returns (slot_out, slot_token) where each slot is one (token, topk) pair.
Caller applies SiLU+Mul per-slot, then L2, then final scatter with weights.
Takes pre-built slot mapping (slot_expert_ids, slot_token) from the outer
routing logic. Returns (slot_out, slot_token) where each slot is one
(token, topk) pair.
"""
K_half = x_fp4.shape[1]
K = K_half * 2
N = l1_weights.shape[2] # 2 * INTERMEDIATE = 6144
if MEGA_MOE_DEBUG:
print(f"[nvfp4_moe_l1] tokens={x_fp4.shape[0]} K={K} N={N} native=1")
print(f"[nvfp4_moe_l1] tokens={x_fp4.shape[0]} K={K} N={N} slots={slot_expert_ids.shape[0]}")
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
if not sfb_prepacked:
@@ -151,7 +153,8 @@ def nvfp4_mega_moe_l1(
slot_out, slot_token = cutlass_grouped_nvfp4_gemm(
x_fp4, x_sf_fp8,
l1_weights, w_sf_fp8,
topk_ids,
slot_expert_ids, # 1D per-slot expert IDs
slot_token, # 1D per-slot token indices
alpha=alpha,
sfb_prepacked=sfb_prepacked,
)
@@ -369,7 +372,7 @@ def nvfp4_mega_moe_full(
# Step 2: L1 GEMM — slot-based, no routing weights, prepacked SFB
l1_slots, _ = nvfp4_mega_moe_l1(
x_fp4, x_sf, l1_w, l1_sf_prepacked,
topk_ids_local,
slot_expert_local, slot_token,
alpha=l1_alpha,
sfb_prepacked=True,
) # (num_slots, 2*INTER) bfloat16