diff --git a/cutedsl/kernel/moe/fused_swiglu_grouped_mm.py b/cutedsl/kernel/moe/fused_swiglu_grouped_mm.py index f2e910e6..7d916613 100644 --- a/cutedsl/kernel/moe/fused_swiglu_grouped_mm.py +++ b/cutedsl/kernel/moe/fused_swiglu_grouped_mm.py @@ -2140,7 +2140,8 @@ class FusedSwiGLUScaledGroupedGemmKernel: acc_pipeline.consumer_release(acc_consumer_state) acc_consumer_state.advance() - # Convert to output dtype, apply global_scale + # ── Fused SwiGLU + NVFP4 Epilogue ── + # Load accumulator, apply global scale acc_vec = cute.zeros_like(tiled_copy_r2s.retile(tTR_rAcc)) if cutlass.const_expr(self.scenario == "2Dx2D"): if k_tile_cnt > 0: @@ -2149,7 +2150,25 @@ class FusedSwiGLUScaledGroupedGemmKernel: acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() if cutlass.const_expr(global_scale_a is not None): acc_vec = acc_vec * alpha - acc_vec = acc_vec.to(self.c_dtype) + + if cutlass.const_expr(self.fused_swiglu): + # ── SwiGLU in registers ── + # With interleaved weights (granularity 8 BF16 = 4 FP4), + # the accumulator N dimension has gate/up pairs adjacent. + # SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)) + neg_acc = acc_vec * cutlass.float32_t(-1.0) + exp_neg = cute.exp(neg_acc) + sigmoid = cutlass.float32_t(1.0) / (cutlass.float32_t(1.0) + exp_neg) + swiglu_result = acc_vec * sigmoid + + # Write SwiGLU result as BF16 to C tensor + # Stage 1 validation: SiLU applied to full acc_vec. + # This confirms cute.exp works on register tensors. + # The gate/up pairing is added in Stage 2. + acc_vec = swiglu_result.to(self.c_dtype) + else: + # Standard path: convert to output dtype + acc_vec = acc_vec.to(self.c_dtype) tRS_rC.store(acc_vec) # RMEM → SMEM