From f2cacfc2f2006ea107f3448226468c3e77d6b3d3 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 08:51:23 +0000 Subject: [PATCH] fix the L2 path and the clamping math --- .../cutlass_nvfp4_gemm/kernel.py | 93 ++++---- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 221 +++++++----------- 2 files changed, 139 insertions(+), 175 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py index f1262d77..a9a64ee1 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py @@ -36,46 +36,58 @@ def cutlass_grouped_nvfp4_gemm( x_sf, # (num_tokens, sf_k) float8_e4m3fn block scales weights, # (E_per_rank, K_half, N) int8 packed E2M1, column-major for CUTLASS weight_sf, # (E_per_rank, sf_k, N) float8_e4m3fn, column-major for CUTLASS - topk_ids, # (num_tokens, NUM_TOPK) int32 - topk_weights, # (num_tokens, NUM_TOPK) float32 + topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale) ): """Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4. - - For each expert, gather the tokens routed to it, run the block-scaled GEMM, - then scatter results back with routing weights. + + Returns slot-based output: one row per (token, topk) slot routed to a local + expert. No routing weights applied — caller handles that at the final scatter. + + Returns: + slot_out: (num_slots, N) bfloat16 — per-slot GEMM results + slot_token: (num_slots,) int64 — token index for each slot """ num_tokens = x_fp4.shape[0] K_half = x_fp4.shape[1] - K = K_half * 2 # Actual K dimension (2 FP4 per byte) - # Weights are (E, K_half, N) column-major (transposed at load time for CUTLASS ColumnMajor B) - N = weights.shape[2] # Output dimension + K = K_half * 2 + N = weights.shape[2] num_experts = weights.shape[0] num_topk = topk_ids.shape[1] - + + # Build slot mapping: which (token, topk) pairs land on local experts? + local_mask = (topk_ids >= 0) & (topk_ids < num_experts) # (num_tokens, num_topk) + slot_token, slot_k = local_mask.nonzero(as_tuple=True) # (num_slots,) + slot_expert = topk_ids[slot_token, slot_k] # (num_slots,) local expert id + + num_slots = slot_token.shape[0] + if MEGA_MOE_DEBUG: print(f"[cutlass_grouped_gemm] tokens={num_tokens} K={K} N={N} " - f"experts={num_experts} topk={num_topk}") - - output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=x_fp4.device) - + f"experts={num_experts} topk={num_topk} slots={num_slots}") + + if num_slots == 0: + slot_out = torch.empty(0, N, dtype=torch.bfloat16, device=x_fp4.device) + return slot_out, slot_token + + # Gather activations for all slots + slot_x = x_fp4[slot_token] # (num_slots, K_half) + slot_x_sf = x_sf[slot_token] # (num_slots, sf_k) + + slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device) + for e in range(num_experts): - # Find tokens routed to this expert - expert_mask = (topk_ids == e) # (num_tokens, num_topk) - token_indices = expert_mask.any(dim=1).nonzero(as_tuple=True)[0] - - if token_indices.numel() == 0: + expert_slots = (slot_expert == e) + if not expert_slots.any(): continue - - # Gather tokens for this expert - expert_x = x_fp4[token_indices] # (num_expert_tokens, K_half) - expert_x_sf = x_sf[token_indices] # (num_expert_tokens, sf_k) - expert_w = weights[e] # (K_half, N) column-major for CUTLASS - expert_w_sf = weight_sf[e] # (sf_k, N) column-major for CUTLASS - - M_expert = token_indices.shape[0] - - # DEBUG: verify data going into GEMM + + e_idx = expert_slots.nonzero(as_tuple=True)[0] + expert_x = slot_x[e_idx] + expert_x_sf = slot_x_sf[e_idx] + expert_w = weights[e] + expert_w_sf = weight_sf[e] + M_expert = e_idx.shape[0] + if e < 3 and M_expert > 0: print(f"[GEMM-IN] expert={e} M={M_expert} N={N} K={K} " f"w shape={expert_w.shape} w_sf shape={expert_w_sf.shape} " @@ -83,19 +95,16 @@ def cutlass_grouped_nvfp4_gemm( f"w_sf range=[{expert_w_sf.to(torch.float32).min().item():.4e}, " f"{expert_w_sf.to(torch.float32).max().item():.4e}] " f"w_sf nonzero_frac={(expert_w_sf.view(torch.uint8) != 0).float().mean().item():.4f}") - - # Run CUTLASS NVFP4 block-scaled GEMM + expert_out = cutlass_nvfp4_blockscaled_gemm( expert_x, expert_x_sf, - expert_w, expert_w_sf, # Pass directly — already (N, K_half) and (N, sf_k) + expert_w, expert_w_sf, M_expert, N, K, alpha=alpha, - ) # (M_expert, N) bfloat16 - - # Check for CUDA errors after each expert GEMM + ) + torch.cuda.current_stream().synchronize() - - # Hard-fail on NaN/Inf — silent skip was hiding bugs + if torch.isnan(expert_out).any() or torch.isinf(expert_out).any(): raise RuntimeError( f"expert {e} of {num_experts}: GEMM emitted NaN/Inf. " @@ -108,11 +117,7 @@ def cutlass_grouped_nvfp4_gemm( f"x_sf nan_frac={torch.isnan(expert_x_sf.to(torch.float32)).float().mean().item():.4f}, " f"w_sf nan_frac={torch.isnan(expert_w_sf.to(torch.float32)).float().mean().item():.4f}" ) - - # Scatter back with routing weights - for t_idx, token_idx in enumerate(token_indices): - for k_idx in range(num_topk): - if topk_ids[token_idx, k_idx] == e: - output[token_idx] += topk_weights[token_idx, k_idx] * expert_out[t_idx] - - return output + + slot_out[e_idx] = expert_out + + return slot_out, slot_token diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 7d2cc026..a6c8b9a8 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -5,8 +5,9 @@ This is the main kernel that replaces fp8_nvfp4_mega_moe from DeepGEMM. Architecture: - L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with UE4M3 scales) -- SiLU+Mul activation +- SiLU+Mul activation (per-slot, BEFORE combining expert paths) - L2 GEMM: down_proj (FP4 x FP4 → BF16 with UE4M3 scales) +- Routing weights applied ONCE at final scatter - NVLink cross-rank sync handled by caller (not this kernel) - Expert parallel: each rank handles NUM_EXPERTS/8 experts @@ -90,82 +91,75 @@ MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0")) def nvfp4_mega_moe_l1( x_fp4, # (num_tokens, K//2) int8 packed E2M1 - x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3 + x_sf, # (num_tokens, sf_k_groups) float8_e4m3fn l1_weights, # (E_per_rank, K//2, 2*INTER) int8, column-major for CUTLASS l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major - topk_ids, # (num_tokens, NUM_TOPK) int32 - topk_weights, # (num_tokens, NUM_TOPK) float32 - num_experts_per_rank, + topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs alpha=1.0, # fp32 scalar from stage_activation global scale ): - """L1 GEMM: gate_up_proj — Native NVFP4 block-scaled MMA. + """L1 GEMM: gate_up_proj — slot-based, no routing weights. - Uses tcgen05.mma.kind::mxf8f6f4.block_scale for native E2M1×E2M1 - with UE4M3 block-16 scaling in tensor cores. - - Falls back to dequantize+BF16 if native path unavailable. + Returns (slot_out, slot_token) where each slot is one (token, topk) pair. + Caller applies SiLU+Mul per-slot, then L2, then final scatter with weights. """ - num_tokens = x_fp4.shape[0] K_half = x_fp4.shape[1] - K = K_half * 2 # HIDDEN = 7168 - N = l1_weights.shape[2] # 2 * INTERMEDIATE = 6144 (column-major: shape is E, K_half, N) + K = K_half * 2 + N = l1_weights.shape[2] # 2 * INTERMEDIATE = 6144 if MEGA_MOE_DEBUG: - print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} " - f"experts={num_experts_per_rank} native=1") + print(f"[nvfp4_moe_l1] tokens={x_fp4.shape[0]} K={K} N={N} native=1") - # DEBUG: verify weight shapes after transpose - if MEGA_MOE_DEBUG: - print(f"[L1-WT] l1_w shape={l1_weights.shape} l1_sf shape={l1_scales.shape} w_sf dtype={l1_scales.dtype}") - - # Unpack uint32 packed UE4M3 scales to float8_e4m3fn x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales - output = cutlass_grouped_nvfp4_gemm( + slot_out, slot_token = cutlass_grouped_nvfp4_gemm( x_fp4, x_sf_fp8, l1_weights, w_sf_fp8, - topk_ids, topk_weights, + topk_ids, alpha=alpha, ) - print(f"[L1-GEMM-OUT] amax={output.abs().max().item():.4e} mean={output.float().mean().item():.4e} nonzero_frac={(output != 0).float().mean().item():.4f}") - return output # (num_tokens, 6144) bfloat16 + print(f"[L1-GEMM-OUT] slots={slot_out.shape[0]} N={N} amax={slot_out.abs().max().item():.4e} mean={slot_out.float().mean().item():.4e}") + return slot_out, slot_token def nvfp4_mega_moe_l2( - x_fp4, # (num_tokens, INTER//2) int8 packed E2M1 - x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3 + x_fp4, # (num_slots, INTER//2) int8 packed E2M1 + x_sf, # (num_slots, sf_k_groups) float8_e4m3fn l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major - topk_ids, # (num_tokens, NUM_TOPK) int32 - topk_weights, # (num_tokens, NUM_TOPK) float32 - num_experts_per_rank, + topk_ids, # (num_tokens, NUM_TOPK) int32 — local expert IDs (for slot mapping) + slot_token, # (num_slots,) int64 — token index per slot (from L1) alpha=1.0, # fp32 scalar from stage_activation global scale ): - """L2 GEMM: down_proj — Native NVFP4 block-scaled MMA. + """L2 GEMM: down_proj — slot-based, no routing weights. - Same pipeline as L1 using native mxf8f6f4.block_scale MMA. + Reuses the same slot mapping from L1 (same slot_token indices). + topk_ids is passed to rebuild the slot→expert mapping. """ - num_tokens = x_fp4.shape[0] K_half = x_fp4.shape[1] - K = K_half * 2 # INTERMEDIATE = 3072 - N = l2_weights.shape[2] # HIDDEN = 7168 (column-major: shape is E, K_half, N) + K = K_half * 2 + N = l2_weights.shape[2] if MEGA_MOE_DEBUG: - print(f"[nvfp4_moe_l2] tokens={num_tokens} K={K} N={N} " - f"experts={num_experts_per_rank} native=1") + print(f"[nvfp4_moe_l2] slots={x_fp4.shape[0]} K={K} N={N} native=1") - # Unpack uint32 packed UE4M3 scales to float8_e4m3fn x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales - output = cutlass_grouped_nvfp4_gemm( + # Build local expert IDs per slot (same mapping as L1) + num_topk = topk_ids.shape[1] + num_experts = l2_weights.shape[0] + local_mask = (topk_ids >= 0) & (topk_ids < num_experts) + _, slot_k = local_mask.nonzero(as_tuple=True) + slot_expert_ids = topk_ids[slot_token, slot_k] # (num_slots,) + + slot_out, _ = cutlass_grouped_nvfp4_gemm( x_fp4, x_sf_fp8, l2_weights, w_sf_fp8, - topk_ids, topk_weights, + slot_expert_ids, # per-slot expert IDs alpha=alpha, ) - return output # (num_tokens, 7168) bfloat16 + return slot_out # (num_slots, HIDDEN) bfloat16 # E2M1 (FP4) representable magnitudes: {0, 0.5, 1, 1.5, 2, 3, 4, 6} @@ -191,37 +185,29 @@ def _quantize_to_e2m1(x_f32): x_blocks = x_f32.reshape(*batch, N // 16, 16) # Per-block absmax determines the scale - block_max = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8, max=448.0) + block_max = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) # Scale so that the max maps to 6.0 (largest E2M1 magnitude) - # Dequant: x_reconstructed = x_e2m1 * scale, where scale = block_max / 6.0 - scale_f32 = block_max / 6.0 + scale_f32 = (block_max / 6.0).clamp(min=1e-8, max=448.0) x_scaled = x_blocks / scale_f32.clamp(min=1e-8) # Find nearest E2M1 magnitude for each value - signs = torch.sign(x_scaled) # +1, -1, or 0 - abs_scaled = x_scaled.abs() # 0..6 range + signs = torch.sign(x_scaled) + abs_scaled = x_scaled.abs() - # Nearest E2M1 magnitude: find closest in {0, 0.5, 1, 1.5, 2, 3, 4, 6} mags = _E2M1_MAGNITUDES.to(device=abs_scaled.device) - # Distance from each value to each magnitude - dists = (abs_scaled.unsqueeze(-1) - mags).abs() # (..., 16, 8) - idx = dists.argmin(dim=-1) # (..., 16) — index into E2M1 magnitudes + dists = (abs_scaled.unsqueeze(-1) - mags).abs() + idx = dists.argmin(dim=-1) - # Clamp to valid range (safety) idx = idx.clamp(0, 7).to(torch.uint8) - # Build 4-bit sign-magnitude nibble: bit3=sign, bits2:0=magnitude index - sign_bit = (signs < 0).to(torch.uint8) # 1 if negative - nibbles = (sign_bit << 3) | idx # (..., 16) uint8, values 0..15 + sign_bit = (signs < 0).to(torch.uint8) + nibbles = (sign_bit << 3) | idx - # Pack 2 nibbles per byte: low nibble = even index, high nibble = odd index nibbles = nibbles.reshape(*batch, N // 2, 2) - packed = (nibbles[..., 1] << 4) | nibbles[..., 0] # (..., N//2) uint8 + packed = (nibbles[..., 1] << 4) | nibbles[..., 0] - # Scale factors: what the GEMM needs to reconstruct the original values - # dequant = e2m1_magnitude * scale, so scale = block_max / 6.0 - sf = scale_f32.squeeze(-1).to(torch.float8_e4m3fn) # (..., N//16) + sf = scale_f32.squeeze(-1).to(torch.float8_e4m3fn) return packed.to(torch.int8), sf @@ -231,32 +217,18 @@ def stage_activation(x_bf16): Two-level quantization matching the NVFP4 weight format: 1. Per-tensor global scale: amax / (6.0 * 448.0) - Normalizes the activation so that block scales fit in UE4M3 range. 2. Per-block (16 values) absmax scaling on the normalized values - Snap to nearest E2M1 representable value: {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} - Pack as 4-bit sign-magnitude nibbles (bit3=sign, bits2:0=mag index) - Block scale = block_max / 6.0 stored as UE4M3 (float8_e4m3fn) Returns (x_fp4, x_sf, input_global_scale) where: x_fp4: packed E2M1 nibbles x_sf: UE4M3 block scales (NOT folded with global scale) input_global_scale: fp32 per-tensor scale, applied as GEMM alpha - - The GEMM applies global scale via alpha: D = alpha * (A_sf * A_fp4) @ (B_sf * B_fp4) - This avoids fp32→UE4M3 round-trip from folding, preserving precision. """ x_f32 = x_bf16.float() - # Per-tensor global scale (same role as weight_scale_2) - # NVFP4 spec: global_scale = amax / (6.0 * 448.0) - # This ensures the largest block scale after normalization is ~448.0, - # which fits exactly in UE4M3 max (448.0 for E4M3). x_amax = x_f32.abs().amax().to(torch.float32).clamp(min=1e-8) input_global_scale = x_amax / (6.0 * 448.0) - # Normalize by global scale before block quantization. - # After this, values are in a range where block_max / 6.0 ≤ 448.0, - # so block scales fit in UE4M3 without saturation. x_normalized = x_f32 / input_global_scale x_fp4, x_sf = _quantize_to_e2m1(x_normalized) @@ -274,22 +246,14 @@ def nvfp4_mega_moe_full( ): """Full mega_moe forward pass — replaces deep_gemm.mega.fp8_nvfp4_mega_moe. - API matches the DeepGEMM fp8_nvfp4_mega_moe call signature used in - the vLLM deepseek_v4.py patch: - - fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer, - activation_clamp=..., fast_math=...) - - Pipeline: - 1. Read staged activation from symm_buffer (already quantized by staging kernel) - 2. L1 GEMM: gate_up_proj (native NVFP4 block-scaled MMA) - 3. SiLU + Mul (activation) - 4. Quantize L1 output → FP4 + UE4M3 scales - 5. L2 GEMM: down_proj (native NVFP4 block-scaled MMA) - 6. Write to y (caller handles cross-rank all-reduce) - - Uses tcgen05.mma.kind::mxf8f6f4.block_scale for native E2M1×E2M1 - with UE4M3 block-16 scaling in Blackwell tensor cores. + Slot-based pipeline (routing weights applied ONCE at final scatter): + 1. Read staged activation from symm_buffer + 2. L1 GEMM → slot output (num_slots, 2*INTER) — NO routing weights + 3. SiLU + Mul PER SLOT (nonlinearity before combining expert paths) + 4. Quantize activated slots → FP4 + 5. L2 GEMM → slot output (num_slots, HIDDEN) — NO routing weights + 6. Final scatter: y.index_add_(0, slot_token, slot_weight * l2_slots) + Single routing weight application. """ num_tokens = y.shape[0] device = y.device @@ -318,87 +282,82 @@ def nvfp4_mega_moe_full( # Step 1: Read staged activation from symm_buffer x_fp4 = symm_buffer.x[:num_tokens] x_sf = symm_buffer.x_sf[:num_tokens] - l1_global_scale = symm_buffer.input_global_scale # fp32, from stage_activation + l1_global_scale = symm_buffer.input_global_scale topk_ids = symm_buffer.topk_idx[:num_tokens] topk_weights = symm_buffer.topk_weights[:num_tokens] - # ALWAYS-ON debug: alpha and scale ranges _x_sf_f32 = x_sf.to(torch.float32) _igs = l1_global_scale if isinstance(l1_global_scale, float) else l1_global_scale.item() if hasattr(l1_global_scale, 'item') else float(l1_global_scale) if MEGA_MOE_DEBUG: - print(f"[ALPHA L1] alpha={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}] x_fp4_absmax={x_fp4.view(torch.int8).abs().max().item()}") + print(f"[ALPHA L1] alpha={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}]") - # Convert global expert IDs to local expert IDs. - # vLLM's symm_buffer stores global IDs (0..383) but our weight tensors - # are indexed by local ID (0..47). Each rank handles a contiguous chunk: - # rank r gets experts [r*E_per_rank, (r+1)*E_per_rank). + # Convert global expert IDs to local expert IDs num_experts_per_rank = l1_w.shape[0] experts_start_idx = symm_buffer.experts_start_idx topk_ids_local = topk_ids - experts_start_idx - # Routing diagnostic (ungated — needed to diagnose zero-GEMM on specific ranks) + # Build slot mapping for this rank local_topk = (topk_ids >= experts_start_idx) & (topk_ids < experts_start_idx + num_experts_per_rank) + slot_token, slot_k = local_topk.nonzero(as_tuple=True) + slot_expert_local = topk_ids_local[slot_token, slot_k] + slot_weight = topk_weights[slot_token, slot_k] + num_slots = slot_token.shape[0] + tokens_routed_locally = local_topk.any(dim=-1).sum().item() - print(f"[ROUTING] tokens_routed_local={tokens_routed_locally}/{topk_ids.shape[0]} " - f"unique_local_experts={local_topk.long().sum().item()}") + print(f"[ROUTING] tokens_routed_local={tokens_routed_locally}/{num_tokens} " + f"num_slots={num_slots}") if MEGA_MOE_DEBUG: print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} " - f"topk_ids={topk_ids.shape} topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} " + f"topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} " f"local: {topk_ids_local.min().item()}-{topk_ids_local.max().item()} " - f"l1_w={l1_w.shape} l2_w={l2_w.shape}") + f"slots={num_slots}") - # NaN-trace: check activation scales at L1 input - if MEGA_MOE_DEBUG: - x_sf_f32 = x_sf.to(torch.float32) - print(f"[L1-in] x_sf nan={torch.isnan(x_sf_f32).any().item()} " - f"inf={torch.isinf(x_sf_f32).any().item()} " - f"min={x_sf_f32.min().item():.4e} max={x_sf_f32.max().item():.4e}") + # Handle no local slots + if num_slots == 0: + y.zero_() + return - # Step 2: L1 GEMM (native NVFP4 block-scaled MMA) - l1_output = nvfp4_mega_moe_l1( + # Step 2: L1 GEMM — slot-based, no routing weights + l1_slots, _ = nvfp4_mega_moe_l1( x_fp4, x_sf, l1_w, l1_sf, - topk_ids_local, topk_weights, num_experts_per_rank, + topk_ids_local, alpha=l1_global_scale, - ) + ) # (num_slots, 2*INTER) bfloat16 - # NaN-trace: check L1 output if MEGA_MOE_DEBUG: - print(f"[L1-out] nan={torch.isnan(l1_output).any().item()} " - f"inf={torch.isinf(l1_output).any().item()} " - f"abs_max={l1_output.abs().max().item():.4e}") + print(f"[L1-out] nan={torch.isnan(l1_slots).any().item()} " + f"abs_max={l1_slots.abs().max().item():.4e}") - # Step 3: SiLU + Mul - gate, up = l1_output.chunk(2, dim=-1) + # Step 3: SiLU + Mul PER SLOT — nonlinearity before combining paths + gate, up = l1_slots.chunk(2, dim=-1) activated = torch.nn.functional.silu(gate) * up if activation_clamp is not None: activated = activated.clamp(max=activation_clamp) - # NaN-trace: check SiLU output if MEGA_MOE_DEBUG: print(f"[silu] nan={torch.isnan(activated).any().item()} " f"abs_max={activated.abs().max().item():.4e}") - # Step 4: Quantize L1 output → FP4 + # Step 4: Quantize activated slots → FP4 l1_fp4, l1_sf_out, l2_global_scale = stage_activation(activated) - # ALWAYS-ON debug: L2 alpha and scale ranges - _l1sf_f32 = l1_sf_out.to(torch.float32) - _l2gs = l2_global_scale if isinstance(l2_global_scale, float) else l2_global_scale.item() if hasattr(l2_global_scale, 'item') else float(l2_global_scale) if MEGA_MOE_DEBUG: - print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}] activated amax={activated.abs().max().item():.4e}") + _l1sf_f32 = l1_sf_out.to(torch.float32) + _l2gs = l2_global_scale if isinstance(l2_global_scale, float) else l2_global_scale.item() + print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}]") - # Step 5: L2 GEMM (native NVFP4 block-scaled MMA) - l2_output = nvfp4_mega_moe_l2( + # Step 5: L2 GEMM — slot-based, no routing weights + l2_slots = nvfp4_mega_moe_l2( l1_fp4, l1_sf_out, l2_w, l2_sf, - topk_ids_local, topk_weights, num_experts_per_rank, + topk_ids_local, slot_token, alpha=l2_global_scale, - ) + ) # (num_slots, HIDDEN) bfloat16 - # NaN-trace: check L2 output if MEGA_MOE_DEBUG: - print(f"[L2-out] nan={torch.isnan(l2_output).any().item()} " - f"abs_max={l2_output.abs().max().item():.4e}") + print(f"[L2-out] nan={torch.isnan(l2_slots).any().item()} " + f"abs_max={l2_slots.abs().max().item():.4e}") - # Step 6: Write to output (caller handles cross-rank all-reduce) - y.copy_(l2_output) + # Step 6: Final scatter — routing weights applied ONCE + y.zero_() + y.index_add_(0, slot_token, slot_weight[:, None] * l2_slots)