fix the L2 path and the clamping math

This commit is contained in:
2026-05-15 08:51:23 +00:00
parent d22dae2df3
commit f2cacfc2f2
2 changed files with 139 additions and 175 deletions

View File

@@ -36,46 +36,58 @@ def cutlass_grouped_nvfp4_gemm(
x_sf, # (num_tokens, sf_k) float8_e4m3fn block scales
weights, # (E_per_rank, K_half, N) int8 packed E2M1, column-major for CUTLASS
weight_sf, # (E_per_rank, sf_k, N) float8_e4m3fn, column-major for CUTLASS
topk_ids, # (num_tokens, NUM_TOPK) int32
topk_weights, # (num_tokens, NUM_TOPK) float32
topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs
alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale)
):
"""Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4.
For each expert, gather the tokens routed to it, run the block-scaled GEMM,
then scatter results back with routing weights.
Returns slot-based output: one row per (token, topk) slot routed to a local
expert. No routing weights applied — caller handles that at the final scatter.
Returns:
slot_out: (num_slots, N) bfloat16 — per-slot GEMM results
slot_token: (num_slots,) int64 — token index for each slot
"""
num_tokens = x_fp4.shape[0]
K_half = x_fp4.shape[1]
K = K_half * 2 # Actual K dimension (2 FP4 per byte)
# Weights are (E, K_half, N) column-major (transposed at load time for CUTLASS ColumnMajor B)
N = weights.shape[2] # Output dimension
K = K_half * 2
N = weights.shape[2]
num_experts = weights.shape[0]
num_topk = topk_ids.shape[1]
# Build slot mapping: which (token, topk) pairs land on local experts?
local_mask = (topk_ids >= 0) & (topk_ids < num_experts) # (num_tokens, num_topk)
slot_token, slot_k = local_mask.nonzero(as_tuple=True) # (num_slots,)
slot_expert = topk_ids[slot_token, slot_k] # (num_slots,) local expert id
num_slots = slot_token.shape[0]
if MEGA_MOE_DEBUG:
print(f"[cutlass_grouped_gemm] tokens={num_tokens} K={K} N={N} "
f"experts={num_experts} topk={num_topk}")
output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=x_fp4.device)
f"experts={num_experts} topk={num_topk} slots={num_slots}")
if num_slots == 0:
slot_out = torch.empty(0, N, dtype=torch.bfloat16, device=x_fp4.device)
return slot_out, slot_token
# Gather activations for all slots
slot_x = x_fp4[slot_token] # (num_slots, K_half)
slot_x_sf = x_sf[slot_token] # (num_slots, sf_k)
slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device)
for e in range(num_experts):
# Find tokens routed to this expert
expert_mask = (topk_ids == e) # (num_tokens, num_topk)
token_indices = expert_mask.any(dim=1).nonzero(as_tuple=True)[0]
if token_indices.numel() == 0:
expert_slots = (slot_expert == e)
if not expert_slots.any():
continue
# Gather tokens for this expert
expert_x = x_fp4[token_indices] # (num_expert_tokens, K_half)
expert_x_sf = x_sf[token_indices] # (num_expert_tokens, sf_k)
expert_w = weights[e] # (K_half, N) column-major for CUTLASS
expert_w_sf = weight_sf[e] # (sf_k, N) column-major for CUTLASS
M_expert = token_indices.shape[0]
# DEBUG: verify data going into GEMM
e_idx = expert_slots.nonzero(as_tuple=True)[0]
expert_x = slot_x[e_idx]
expert_x_sf = slot_x_sf[e_idx]
expert_w = weights[e]
expert_w_sf = weight_sf[e]
M_expert = e_idx.shape[0]
if e < 3 and M_expert > 0:
print(f"[GEMM-IN] expert={e} M={M_expert} N={N} K={K} "
f"w shape={expert_w.shape} w_sf shape={expert_w_sf.shape} "
@@ -83,19 +95,16 @@ def cutlass_grouped_nvfp4_gemm(
f"w_sf range=[{expert_w_sf.to(torch.float32).min().item():.4e}, "
f"{expert_w_sf.to(torch.float32).max().item():.4e}] "
f"w_sf nonzero_frac={(expert_w_sf.view(torch.uint8) != 0).float().mean().item():.4f}")
# Run CUTLASS NVFP4 block-scaled GEMM
expert_out = cutlass_nvfp4_blockscaled_gemm(
expert_x, expert_x_sf,
expert_w, expert_w_sf, # Pass directly — already (N, K_half) and (N, sf_k)
expert_w, expert_w_sf,
M_expert, N, K,
alpha=alpha,
) # (M_expert, N) bfloat16
# Check for CUDA errors after each expert GEMM
)
torch.cuda.current_stream().synchronize()
# Hard-fail on NaN/Inf — silent skip was hiding bugs
if torch.isnan(expert_out).any() or torch.isinf(expert_out).any():
raise RuntimeError(
f"expert {e} of {num_experts}: GEMM emitted NaN/Inf. "
@@ -108,11 +117,7 @@ def cutlass_grouped_nvfp4_gemm(
f"x_sf nan_frac={torch.isnan(expert_x_sf.to(torch.float32)).float().mean().item():.4f}, "
f"w_sf nan_frac={torch.isnan(expert_w_sf.to(torch.float32)).float().mean().item():.4f}"
)
# Scatter back with routing weights
for t_idx, token_idx in enumerate(token_indices):
for k_idx in range(num_topk):
if topk_ids[token_idx, k_idx] == e:
output[token_idx] += topk_weights[token_idx, k_idx] * expert_out[t_idx]
return output
slot_out[e_idx] = expert_out
return slot_out, slot_token

View File

@@ -5,8 +5,9 @@ This is the main kernel that replaces fp8_nvfp4_mega_moe from DeepGEMM.
Architecture:
- L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with UE4M3 scales)
- SiLU+Mul activation
- SiLU+Mul activation (per-slot, BEFORE combining expert paths)
- L2 GEMM: down_proj (FP4 x FP4 → BF16 with UE4M3 scales)
- Routing weights applied ONCE at final scatter
- NVLink cross-rank sync handled by caller (not this kernel)
- Expert parallel: each rank handles NUM_EXPERTS/8 experts
@@ -90,82 +91,75 @@ MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0"))
def nvfp4_mega_moe_l1(
x_fp4, # (num_tokens, K//2) int8 packed E2M1
x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3
x_sf, # (num_tokens, sf_k_groups) float8_e4m3fn
l1_weights, # (E_per_rank, K//2, 2*INTER) int8, column-major for CUTLASS
l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major
topk_ids, # (num_tokens, NUM_TOPK) int32
topk_weights, # (num_tokens, NUM_TOPK) float32
num_experts_per_rank,
topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs
alpha=1.0, # fp32 scalar from stage_activation global scale
):
"""L1 GEMM: gate_up_proj — Native NVFP4 block-scaled MMA.
"""L1 GEMM: gate_up_proj — slot-based, no routing weights.
Uses tcgen05.mma.kind::mxf8f6f4.block_scale for native E2M1×E2M1
with UE4M3 block-16 scaling in tensor cores.
Falls back to dequantize+BF16 if native path unavailable.
Returns (slot_out, slot_token) where each slot is one (token, topk) pair.
Caller applies SiLU+Mul per-slot, then L2, then final scatter with weights.
"""
num_tokens = x_fp4.shape[0]
K_half = x_fp4.shape[1]
K = K_half * 2 # HIDDEN = 7168
N = l1_weights.shape[2] # 2 * INTERMEDIATE = 6144 (column-major: shape is E, K_half, N)
K = K_half * 2
N = l1_weights.shape[2] # 2 * INTERMEDIATE = 6144
if MEGA_MOE_DEBUG:
print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} "
f"experts={num_experts_per_rank} native=1")
print(f"[nvfp4_moe_l1] tokens={x_fp4.shape[0]} K={K} N={N} native=1")
# DEBUG: verify weight shapes after transpose
if MEGA_MOE_DEBUG:
print(f"[L1-WT] l1_w shape={l1_weights.shape} l1_sf shape={l1_scales.shape} w_sf dtype={l1_scales.dtype}")
# Unpack uint32 packed UE4M3 scales to float8_e4m3fn
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales
output = cutlass_grouped_nvfp4_gemm(
slot_out, slot_token = cutlass_grouped_nvfp4_gemm(
x_fp4, x_sf_fp8,
l1_weights, w_sf_fp8,
topk_ids, topk_weights,
topk_ids,
alpha=alpha,
)
print(f"[L1-GEMM-OUT] amax={output.abs().max().item():.4e} mean={output.float().mean().item():.4e} nonzero_frac={(output != 0).float().mean().item():.4f}")
return output # (num_tokens, 6144) bfloat16
print(f"[L1-GEMM-OUT] slots={slot_out.shape[0]} N={N} amax={slot_out.abs().max().item():.4e} mean={slot_out.float().mean().item():.4e}")
return slot_out, slot_token
def nvfp4_mega_moe_l2(
x_fp4, # (num_tokens, INTER//2) int8 packed E2M1
x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3
x_fp4, # (num_slots, INTER//2) int8 packed E2M1
x_sf, # (num_slots, sf_k_groups) float8_e4m3fn
l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS
l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major
topk_ids, # (num_tokens, NUM_TOPK) int32
topk_weights, # (num_tokens, NUM_TOPK) float32
num_experts_per_rank,
topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs (for slot mapping)
slot_token, # (num_slots,) int64 — token index per slot (from L1)
alpha=1.0, # fp32 scalar from stage_activation global scale
):
"""L2 GEMM: down_proj — Native NVFP4 block-scaled MMA.
"""L2 GEMM: down_proj — slot-based, no routing weights.
Same pipeline as L1 using native mxf8f6f4.block_scale MMA.
Reuses the same slot mapping from L1 (same slot_token indices).
topk_ids is passed to rebuild the slot→expert mapping.
"""
num_tokens = x_fp4.shape[0]
K_half = x_fp4.shape[1]
K = K_half * 2 # INTERMEDIATE = 3072
N = l2_weights.shape[2] # HIDDEN = 7168 (column-major: shape is E, K_half, N)
K = K_half * 2
N = l2_weights.shape[2]
if MEGA_MOE_DEBUG:
print(f"[nvfp4_moe_l2] tokens={num_tokens} K={K} N={N} "
f"experts={num_experts_per_rank} native=1")
print(f"[nvfp4_moe_l2] slots={x_fp4.shape[0]} K={K} N={N} native=1")
# Unpack uint32 packed UE4M3 scales to float8_e4m3fn
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales
output = cutlass_grouped_nvfp4_gemm(
# Build local expert IDs per slot (same mapping as L1)
num_topk = topk_ids.shape[1]
num_experts = l2_weights.shape[0]
local_mask = (topk_ids >= 0) & (topk_ids < num_experts)
_, slot_k = local_mask.nonzero(as_tuple=True)
slot_expert_ids = topk_ids[slot_token, slot_k] # (num_slots,)
slot_out, _ = cutlass_grouped_nvfp4_gemm(
x_fp4, x_sf_fp8,
l2_weights, w_sf_fp8,
topk_ids, topk_weights,
slot_expert_ids, # per-slot expert IDs
alpha=alpha,
)
return output # (num_tokens, 7168) bfloat16
return slot_out # (num_slots, HIDDEN) bfloat16
# E2M1 (FP4) representable magnitudes: {0, 0.5, 1, 1.5, 2, 3, 4, 6}
@@ -191,37 +185,29 @@ def _quantize_to_e2m1(x_f32):
x_blocks = x_f32.reshape(*batch, N // 16, 16)
# Per-block absmax determines the scale
block_max = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8, max=448.0)
block_max = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
# Scale so that the max maps to 6.0 (largest E2M1 magnitude)
# Dequant: x_reconstructed = x_e2m1 * scale, where scale = block_max / 6.0
scale_f32 = block_max / 6.0
scale_f32 = (block_max / 6.0).clamp(min=1e-8, max=448.0)
x_scaled = x_blocks / scale_f32.clamp(min=1e-8)
# Find nearest E2M1 magnitude for each value
signs = torch.sign(x_scaled) # +1, -1, or 0
abs_scaled = x_scaled.abs() # 0..6 range
signs = torch.sign(x_scaled)
abs_scaled = x_scaled.abs()
# Nearest E2M1 magnitude: find closest in {0, 0.5, 1, 1.5, 2, 3, 4, 6}
mags = _E2M1_MAGNITUDES.to(device=abs_scaled.device)
# Distance from each value to each magnitude
dists = (abs_scaled.unsqueeze(-1) - mags).abs() # (..., 16, 8)
idx = dists.argmin(dim=-1) # (..., 16) — index into E2M1 magnitudes
dists = (abs_scaled.unsqueeze(-1) - mags).abs()
idx = dists.argmin(dim=-1)
# Clamp to valid range (safety)
idx = idx.clamp(0, 7).to(torch.uint8)
# Build 4-bit sign-magnitude nibble: bit3=sign, bits2:0=magnitude index
sign_bit = (signs < 0).to(torch.uint8) # 1 if negative
nibbles = (sign_bit << 3) | idx # (..., 16) uint8, values 0..15
sign_bit = (signs < 0).to(torch.uint8)
nibbles = (sign_bit << 3) | idx
# Pack 2 nibbles per byte: low nibble = even index, high nibble = odd index
nibbles = nibbles.reshape(*batch, N // 2, 2)
packed = (nibbles[..., 1] << 4) | nibbles[..., 0] # (..., N//2) uint8
packed = (nibbles[..., 1] << 4) | nibbles[..., 0]
# Scale factors: what the GEMM needs to reconstruct the original values
# dequant = e2m1_magnitude * scale, so scale = block_max / 6.0
sf = scale_f32.squeeze(-1).to(torch.float8_e4m3fn) # (..., N//16)
sf = scale_f32.squeeze(-1).to(torch.float8_e4m3fn)
return packed.to(torch.int8), sf
@@ -231,32 +217,18 @@ def stage_activation(x_bf16):
Two-level quantization matching the NVFP4 weight format:
1. Per-tensor global scale: amax / (6.0 * 448.0)
Normalizes the activation so that block scales fit in UE4M3 range.
2. Per-block (16 values) absmax scaling on the normalized values
Snap to nearest E2M1 representable value: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}
Pack as 4-bit sign-magnitude nibbles (bit3=sign, bits2:0=mag index)
Block scale = block_max / 6.0 stored as UE4M3 (float8_e4m3fn)
Returns (x_fp4, x_sf, input_global_scale) where:
x_fp4: packed E2M1 nibbles
x_sf: UE4M3 block scales (NOT folded with global scale)
input_global_scale: fp32 per-tensor scale, applied as GEMM alpha
The GEMM applies global scale via alpha: D = alpha * (A_sf * A_fp4) @ (B_sf * B_fp4)
This avoids fp32→UE4M3 round-trip from folding, preserving precision.
"""
x_f32 = x_bf16.float()
# Per-tensor global scale (same role as weight_scale_2)
# NVFP4 spec: global_scale = amax / (6.0 * 448.0)
# This ensures the largest block scale after normalization is ~448.0,
# which fits exactly in UE4M3 max (448.0 for E4M3).
x_amax = x_f32.abs().amax().to(torch.float32).clamp(min=1e-8)
input_global_scale = x_amax / (6.0 * 448.0)
# Normalize by global scale before block quantization.
# After this, values are in a range where block_max / 6.0 ≤ 448.0,
# so block scales fit in UE4M3 without saturation.
x_normalized = x_f32 / input_global_scale
x_fp4, x_sf = _quantize_to_e2m1(x_normalized)
@@ -274,22 +246,14 @@ def nvfp4_mega_moe_full(
):
"""Full mega_moe forward pass — replaces deep_gemm.mega.fp8_nvfp4_mega_moe.
API matches the DeepGEMM fp8_nvfp4_mega_moe call signature used in
the vLLM deepseek_v4.py patch:
fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer,
activation_clamp=..., fast_math=...)
Pipeline:
1. Read staged activation from symm_buffer (already quantized by staging kernel)
2. L1 GEMM: gate_up_proj (native NVFP4 block-scaled MMA)
3. SiLU + Mul (activation)
4. Quantize L1 output → FP4 + UE4M3 scales
5. L2 GEMM: down_proj (native NVFP4 block-scaled MMA)
6. Write to y (caller handles cross-rank all-reduce)
Uses tcgen05.mma.kind::mxf8f6f4.block_scale for native E2M1×E2M1
with UE4M3 block-16 scaling in Blackwell tensor cores.
Slot-based pipeline (routing weights applied ONCE at final scatter):
1. Read staged activation from symm_buffer
2. L1 GEMM → slot output (num_slots, 2*INTER) — NO routing weights
3. SiLU + Mul PER SLOT (nonlinearity before combining expert paths)
4. Quantize activated slots → FP4
5. L2 GEMM → slot output (num_slots, HIDDEN) — NO routing weights
6. Final scatter: y.index_add_(0, slot_token, slot_weight * l2_slots)
Single routing weight application.
"""
num_tokens = y.shape[0]
device = y.device
@@ -318,87 +282,82 @@ def nvfp4_mega_moe_full(
# Step 1: Read staged activation from symm_buffer
x_fp4 = symm_buffer.x[:num_tokens]
x_sf = symm_buffer.x_sf[:num_tokens]
l1_global_scale = symm_buffer.input_global_scale # fp32, from stage_activation
l1_global_scale = symm_buffer.input_global_scale
topk_ids = symm_buffer.topk_idx[:num_tokens]
topk_weights = symm_buffer.topk_weights[:num_tokens]
# ALWAYS-ON debug: alpha and scale ranges
_x_sf_f32 = x_sf.to(torch.float32)
_igs = l1_global_scale if isinstance(l1_global_scale, float) else l1_global_scale.item() if hasattr(l1_global_scale, 'item') else float(l1_global_scale)
if MEGA_MOE_DEBUG:
print(f"[ALPHA L1] alpha={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}] x_fp4_absmax={x_fp4.view(torch.int8).abs().max().item()}")
print(f"[ALPHA L1] alpha={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}]")
# Convert global expert IDs to local expert IDs.
# vLLM's symm_buffer stores global IDs (0..383) but our weight tensors
# are indexed by local ID (0..47). Each rank handles a contiguous chunk:
# rank r gets experts [r*E_per_rank, (r+1)*E_per_rank).
# Convert global expert IDs to local expert IDs
num_experts_per_rank = l1_w.shape[0]
experts_start_idx = symm_buffer.experts_start_idx
topk_ids_local = topk_ids - experts_start_idx
# Routing diagnostic (ungated — needed to diagnose zero-GEMM on specific ranks)
# Build slot mapping for this rank
local_topk = (topk_ids >= experts_start_idx) & (topk_ids < experts_start_idx + num_experts_per_rank)
slot_token, slot_k = local_topk.nonzero(as_tuple=True)
slot_expert_local = topk_ids_local[slot_token, slot_k]
slot_weight = topk_weights[slot_token, slot_k]
num_slots = slot_token.shape[0]
tokens_routed_locally = local_topk.any(dim=-1).sum().item()
print(f"[ROUTING] tokens_routed_local={tokens_routed_locally}/{topk_ids.shape[0]} "
f"unique_local_experts={local_topk.long().sum().item()}")
print(f"[ROUTING] tokens_routed_local={tokens_routed_locally}/{num_tokens} "
f"num_slots={num_slots}")
if MEGA_MOE_DEBUG:
print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} "
f"topk_ids={topk_ids.shape} topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} "
f"topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} "
f"local: {topk_ids_local.min().item()}-{topk_ids_local.max().item()} "
f"l1_w={l1_w.shape} l2_w={l2_w.shape}")
f"slots={num_slots}")
# NaN-trace: check activation scales at L1 input
if MEGA_MOE_DEBUG:
x_sf_f32 = x_sf.to(torch.float32)
print(f"[L1-in] x_sf nan={torch.isnan(x_sf_f32).any().item()} "
f"inf={torch.isinf(x_sf_f32).any().item()} "
f"min={x_sf_f32.min().item():.4e} max={x_sf_f32.max().item():.4e}")
# Handle no local slots
if num_slots == 0:
y.zero_()
return
# Step 2: L1 GEMM (native NVFP4 block-scaled MMA)
l1_output = nvfp4_mega_moe_l1(
# Step 2: L1 GEMM — slot-based, no routing weights
l1_slots, _ = nvfp4_mega_moe_l1(
x_fp4, x_sf, l1_w, l1_sf,
topk_ids_local, topk_weights, num_experts_per_rank,
topk_ids_local,
alpha=l1_global_scale,
)
) # (num_slots, 2*INTER) bfloat16
# NaN-trace: check L1 output
if MEGA_MOE_DEBUG:
print(f"[L1-out] nan={torch.isnan(l1_output).any().item()} "
f"inf={torch.isinf(l1_output).any().item()} "
f"abs_max={l1_output.abs().max().item():.4e}")
print(f"[L1-out] nan={torch.isnan(l1_slots).any().item()} "
f"abs_max={l1_slots.abs().max().item():.4e}")
# Step 3: SiLU + Mul
gate, up = l1_output.chunk(2, dim=-1)
# Step 3: SiLU + Mul PER SLOT — nonlinearity before combining paths
gate, up = l1_slots.chunk(2, dim=-1)
activated = torch.nn.functional.silu(gate) * up
if activation_clamp is not None:
activated = activated.clamp(max=activation_clamp)
# NaN-trace: check SiLU output
if MEGA_MOE_DEBUG:
print(f"[silu] nan={torch.isnan(activated).any().item()} "
f"abs_max={activated.abs().max().item():.4e}")
# Step 4: Quantize L1 output → FP4
# Step 4: Quantize activated slots → FP4
l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated)
# ALWAYS-ON debug: L2 alpha and scale ranges
_l1sf_f32 = l1_sf_out.to(torch.float32)
_l2gs = l2_global_scale if isinstance(l2_global_scale, float) else l2_global_scale.item() if hasattr(l2_global_scale, 'item') else float(l2_global_scale)
if MEGA_MOE_DEBUG:
print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}] activated amax={activated.abs().max().item():.4e}")
_l1sf_f32 = l1_sf_out.to(torch.float32)
_l2gs = l2_global_scale if isinstance(l2_global_scale, float) else l2_global_scale.item()
print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}]")
# Step 5: L2 GEMM (native NVFP4 block-scaled MMA)
l2_output = nvfp4_mega_moe_l2(
# Step 5: L2 GEMM — slot-based, no routing weights
l2_slots = nvfp4_mega_moe_l2(
l1_fp4, l1_sf_out, l2_w, l2_sf,
topk_ids_local, topk_weights, num_experts_per_rank,
topk_ids_local, slot_token,
alpha=l2_global_scale,
)
) # (num_slots, HIDDEN) bfloat16
# NaN-trace: check L2 output
if MEGA_MOE_DEBUG:
print(f"[L2-out] nan={torch.isnan(l2_output).any().item()} "
f"abs_max={l2_output.abs().max().item():.4e}")
print(f"[L2-out] nan={torch.isnan(l2_slots).any().item()} "
f"abs_max={l2_slots.abs().max().item():.4e}")
# Step 6: Write to output (caller handles cross-rank all-reduce)
y.copy_(l2_output)
# Step 6: Final scatter — routing weights applied ONCE
y.zero_()
y.index_add_(0, slot_token, slot_weight[:, None] * l2_slots)