refactor: unify L1/L2 to use 1D slot_expert_ids consistently
Both L1 and L2 now pass pre-built 1D slot_expert_ids and slot_token to cutlass_grouped_nvfp4_gemm instead of the 2D topk_ids. The 2D path was broken for expert parallelism — local_mask matched ALL local experts, producing mismatched slot_token/slot_k lengths that caused vectorized_gather_kernel index out of bounds. cutlass_grouped_nvfp4_gemm now: - Takes 1D slot_expert_ids + optional slot_token - Gathers x_fp4 by slot_token when needed (L1: tokens→slots) - Skips gather when x_fp4 already has num_slots rows (L2)
This commit is contained in:
@@ -54,24 +54,28 @@ def prepack_sfb(SFB, M, N, K):
|
||||
|
||||
|
||||
def cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, # (num_tokens, K_half) int8 packed E2M1
|
||||
x_sf, # (num_tokens, sf_k) float8_e4m3fn block scales
|
||||
x_fp4, # (num_slots_or_tokens, K_half) int8 packed E2M1
|
||||
x_sf, # (num_slots_or_tokens, sf_k) float8_e4m3fn block scales
|
||||
weights, # (E_per_rank, K_half, N) int8 packed E2M1, column-major for CUTLASS
|
||||
weight_sf, # (E_per_rank, sf_k, N) float8_e4m3fn, column-major — or prepacked (E_per_rank, sfb_size) if sfb_prepacked=True
|
||||
topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs
|
||||
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
|
||||
slot_token=None, # (num_slots,) int64 — per-slot token indices (default: arange)
|
||||
alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale)
|
||||
sfb_prepacked=False, # True if weight_sf is already prepacked into CUTLASS layout
|
||||
):
|
||||
"""Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4.
|
||||
|
||||
Returns slot-based output: one row per (token, topk) slot routed to a local
|
||||
expert. No routing weights applied — caller handles that at the final scatter.
|
||||
Takes 1D per-slot expert IDs and token indices (pre-built by caller).
|
||||
Returns slot-based output: one row per (token, topk) slot.
|
||||
|
||||
For L1: x_fp4 has num_tokens rows, slot_token maps slots→rows.
|
||||
For L2: x_fp4 has num_slots rows, slot_token is just arange(num_slots).
|
||||
|
||||
Returns:
|
||||
slot_out: (num_slots, N) bfloat16 — per-slot GEMM results
|
||||
slot_token: (num_slots,) int64 — token index for each slot
|
||||
"""
|
||||
num_slots = x_fp4.shape[0]
|
||||
num_slots = slot_expert_ids.shape[0]
|
||||
K_half = x_fp4.shape[1]
|
||||
K = K_half * 2
|
||||
N = weights.shape[2]
|
||||
@@ -82,30 +86,27 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
slot_token_out = torch.empty(0, dtype=torch.int64, device=x_fp4.device)
|
||||
return slot_out, slot_token_out
|
||||
|
||||
# topk_ids is either:
|
||||
# 2D (num_tokens, num_topk) from L1 — build slot mapping
|
||||
# 1D (num_slots,) from L2 — already per-slot expert IDs
|
||||
if topk_ids.dim() == 2:
|
||||
num_tokens = topk_ids.shape[0]
|
||||
local_mask = (topk_ids >= 0) & (topk_ids < num_experts)
|
||||
slot_token, slot_k = local_mask.nonzero(as_tuple=True)
|
||||
slot_expert = topk_ids[slot_token, slot_k]
|
||||
else:
|
||||
# 1D per-slot expert IDs — slot_token is just arange
|
||||
slot_expert = topk_ids
|
||||
# Use provided slot_token or default to identity mapping
|
||||
if slot_token is None:
|
||||
slot_token = torch.arange(num_slots, device=x_fp4.device)
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} "
|
||||
f"experts={num_experts} sfb_prepacked={sfb_prepacked}")
|
||||
|
||||
slot_x = x_fp4
|
||||
slot_x_sf = x_sf
|
||||
# Gather input rows by slot_token when x_fp4 has more tokens than slots
|
||||
# (L1: x_fp4=num_tokens, L2: x_fp4=num_slots)
|
||||
if x_fp4.shape[0] != num_slots:
|
||||
slot_x = x_fp4[slot_token]
|
||||
slot_x_sf = x_sf[slot_token]
|
||||
else:
|
||||
slot_x = x_fp4
|
||||
slot_x_sf = x_sf
|
||||
|
||||
slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device)
|
||||
|
||||
for e in range(num_experts):
|
||||
expert_slots = (slot_expert == e)
|
||||
expert_slots = (slot_expert_ids == e)
|
||||
if not expert_slots.any():
|
||||
continue
|
||||
|
||||
|
||||
@@ -126,21 +126,23 @@ def nvfp4_mega_moe_l1(
|
||||
x_sf, # (num_tokens, sf_k_groups) float8_e4m3fn
|
||||
l1_weights, # (E_per_rank, K//2, 2*INTER) int8, column-major for CUTLASS
|
||||
l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major — or prepacked
|
||||
topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs
|
||||
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
|
||||
slot_token, # (num_slots,) int64 — token index per slot
|
||||
alpha=1.0, # fp32 scalar from stage_activation global scale
|
||||
sfb_prepacked=False, # True if l1_scales is prepacked CUTLASS layout
|
||||
):
|
||||
"""L1 GEMM: gate_up_proj — slot-based, no routing weights.
|
||||
|
||||
Returns (slot_out, slot_token) where each slot is one (token, topk) pair.
|
||||
Caller applies SiLU+Mul per-slot, then L2, then final scatter with weights.
|
||||
Takes pre-built slot mapping (slot_expert_ids, slot_token) from the outer
|
||||
routing logic. Returns (slot_out, slot_token) where each slot is one
|
||||
(token, topk) pair.
|
||||
"""
|
||||
K_half = x_fp4.shape[1]
|
||||
K = K_half * 2
|
||||
N = l1_weights.shape[2] # 2 * INTERMEDIATE = 6144
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[nvfp4_moe_l1] tokens={x_fp4.shape[0]} K={K} N={N} native=1")
|
||||
print(f"[nvfp4_moe_l1] tokens={x_fp4.shape[0]} K={K} N={N} slots={slot_expert_ids.shape[0]}")
|
||||
|
||||
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
|
||||
if not sfb_prepacked:
|
||||
@@ -151,7 +153,8 @@ def nvfp4_mega_moe_l1(
|
||||
slot_out, slot_token = cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, x_sf_fp8,
|
||||
l1_weights, w_sf_fp8,
|
||||
topk_ids,
|
||||
slot_expert_ids, # 1D per-slot expert IDs
|
||||
slot_token, # 1D per-slot token indices
|
||||
alpha=alpha,
|
||||
sfb_prepacked=sfb_prepacked,
|
||||
)
|
||||
@@ -369,7 +372,7 @@ def nvfp4_mega_moe_full(
|
||||
# Step 2: L1 GEMM — slot-based, no routing weights, prepacked SFB
|
||||
l1_slots, _ = nvfp4_mega_moe_l1(
|
||||
x_fp4, x_sf, l1_w, l1_sf_prepacked,
|
||||
topk_ids_local,
|
||||
slot_expert_local, slot_token,
|
||||
alpha=l1_alpha,
|
||||
sfb_prepacked=True,
|
||||
) # (num_slots, 2*INTER) bfloat16
|
||||
|
||||
Reference in New Issue
Block a user