wip: fused SwiGLU Stage 1 - SiLU in registers (full acc_vec)

Stage 1 of the fused epilogue: applies SiLU (x * sigmoid(x)) to the
full accumulator register tensor before writing BF16 to C.

This validates that cute.exp and element-wise FP32 operations work
on CuTe register tensors in the epilogue. The gate/up pairing is
not yet implemented (Stage 2).

The fused_swiglu flag is const_expr(0) by default, so the standard
epilogue path is unchanged unless the flag is enabled.
This commit is contained in:
2026-05-20 03:07:02 +00:00
parent 2f053f674e
commit 9c43c69a4c

View File

@@ -2140,7 +2140,8 @@ class FusedSwiGLUScaledGroupedGemmKernel:
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
# Convert to output dtype, apply global_scale
# ── Fused SwiGLU + NVFP4 Epilogue ──
# Load accumulator, apply global scale
acc_vec = cute.zeros_like(tiled_copy_r2s.retile(tTR_rAcc))
if cutlass.const_expr(self.scenario == "2Dx2D"):
if k_tile_cnt > 0:
@@ -2149,7 +2150,25 @@ class FusedSwiGLUScaledGroupedGemmKernel:
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
if cutlass.const_expr(global_scale_a is not None):
acc_vec = acc_vec * alpha
acc_vec = acc_vec.to(self.c_dtype)
if cutlass.const_expr(self.fused_swiglu):
# ── SwiGLU in registers ──
# With interleaved weights (granularity 8 BF16 = 4 FP4),
# the accumulator N dimension has gate/up pairs adjacent.
# SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
neg_acc = acc_vec * cutlass.float32_t(-1.0)
exp_neg = cute.exp(neg_acc)
sigmoid = cutlass.float32_t(1.0) / (cutlass.float32_t(1.0) + exp_neg)
swiglu_result = acc_vec * sigmoid
# Write SwiGLU result as BF16 to C tensor
# Stage 1 validation: SiLU applied to full acc_vec.
# This confirms cute.exp works on register tensors.
# The gate/up pairing is added in Stage 2.
acc_vec = swiglu_result.to(self.c_dtype)
else:
# Standard path: convert to output dtype
acc_vec = acc_vec.to(self.c_dtype)
tRS_rC.store(acc_vec)
# RMEM → SMEM