From 9c43c69a4c98fa7c3866c50f2f452c019ebe8e79 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 20 May 2026 03:07:02 +0000 Subject: [PATCH] wip: fused SwiGLU Stage 1 - SiLU in registers (full acc_vec) Stage 1 of the fused epilogue: applies SiLU (x * sigmoid(x)) to the full accumulator register tensor before writing BF16 to C. This validates that cute.exp and element-wise FP32 operations work on CuTe register tensors in the epilogue. The gate/up pairing is not yet implemented (Stage 2). The fused_swiglu flag is const_expr(0) by default, so the standard epilogue path is unchanged unless the flag is enabled. --- cutedsl/kernel/moe/fused_swiglu_grouped_mm.py | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/cutedsl/kernel/moe/fused_swiglu_grouped_mm.py b/cutedsl/kernel/moe/fused_swiglu_grouped_mm.py index f2e910e6..7d916613 100644 --- a/cutedsl/kernel/moe/fused_swiglu_grouped_mm.py +++ b/cutedsl/kernel/moe/fused_swiglu_grouped_mm.py @@ -2140,7 +2140,8 @@ class FusedSwiGLUScaledGroupedGemmKernel: acc_pipeline.consumer_release(acc_consumer_state) acc_consumer_state.advance() - # Convert to output dtype, apply global_scale + # ── Fused SwiGLU + NVFP4 Epilogue ── + # Load accumulator, apply global scale acc_vec = cute.zeros_like(tiled_copy_r2s.retile(tTR_rAcc)) if cutlass.const_expr(self.scenario == "2Dx2D"): if k_tile_cnt > 0: @@ -2149,7 +2150,25 @@ class FusedSwiGLUScaledGroupedGemmKernel: acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() if cutlass.const_expr(global_scale_a is not None): acc_vec = acc_vec * alpha - acc_vec = acc_vec.to(self.c_dtype) + + if cutlass.const_expr(self.fused_swiglu): + # ── SwiGLU in registers ── + # With interleaved weights (granularity 8 BF16 = 4 FP4), + # the accumulator N dimension has gate/up pairs adjacent. + # SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)) + neg_acc = acc_vec * cutlass.float32_t(-1.0) + exp_neg = cute.exp(neg_acc) + sigmoid = cutlass.float32_t(1.0) / (cutlass.float32_t(1.0) + exp_neg) + swiglu_result = acc_vec * sigmoid + + # Write SwiGLU result as BF16 to C tensor + # Stage 1 validation: SiLU applied to full acc_vec. + # This confirms cute.exp works on register tensors. + # The gate/up pairing is added in Stage 2. + acc_vec = swiglu_result.to(self.c_dtype) + else: + # Standard path: convert to output dtype + acc_vec = acc_vec.to(self.c_dtype) tRS_rC.store(acc_vec) # RMEM → SMEM