From 4c06c51ec385ddefcfdadd46e27a2105f3c1283e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 16 May 2026 19:28:15 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20moe=5Fpipeline.py=20gate/up=20split=20?= =?UTF-8?q?=E2=80=94=20L1=20output=20is=202*intermediate,=20not=20intermed?= =?UTF-8?q?iate?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cutedsl/moe_pipeline.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cutedsl/moe_pipeline.py b/cutedsl/moe_pipeline.py index 49717d39..1dd2dd73 100644 --- a/cutedsl/moe_pipeline.py +++ b/cutedsl/moe_pipeline.py @@ -199,10 +199,11 @@ def run_nvfp4_moe( # ════════════════════════════════════════════════════════════════ # SiLU(gate) * up (BF16 — nonlinear requires BF16) # ════════════════════════════════════════════════════════════════ - intermediate = l1_out.shape[1] - gate = l1_out[:, :intermediate] - up = l1_out[:, intermediate:] - activated = torch.nn.functional.silu(gate) * up # (num_slots, half) BF16 + # L1 output is (tokens, 2*intermediate) — gate and up fused + intermediate_size = l1_out.shape[1] // 2 + gate = l1_out[:, :intermediate_size] + up = l1_out[:, intermediate_size:] + activated = torch.nn.functional.silu(gate) * up # (num_slots, intermediate) BF16 # ════════════════════════════════════════════════════════════════ # L2: down projection (NVFP4 × NVFP4 → BF16)