diff --git a/cutedsl/moe_pipeline.py b/cutedsl/moe_pipeline.py index 49717d39..1dd2dd73 100644 --- a/cutedsl/moe_pipeline.py +++ b/cutedsl/moe_pipeline.py @@ -199,10 +199,11 @@ def run_nvfp4_moe( # ════════════════════════════════════════════════════════════════ # SiLU(gate) * up (BF16 — nonlinear requires BF16) # ════════════════════════════════════════════════════════════════ - intermediate = l1_out.shape[1] - gate = l1_out[:, :intermediate] - up = l1_out[:, intermediate:] - activated = torch.nn.functional.silu(gate) * up # (num_slots, half) BF16 + # L1 output is (tokens, 2*intermediate) — gate and up fused + intermediate_size = l1_out.shape[1] // 2 + gate = l1_out[:, :intermediate_size] + up = l1_out[:, intermediate_size:] + activated = torch.nn.functional.silu(gate) * up # (num_slots, intermediate) BF16 # ════════════════════════════════════════════════════════════════ # L2: down projection (NVFP4 × NVFP4 → BF16)