wip: fused SwiGLU Stage 1 - SiLU in registers (full acc_vec)
Stage 1 of the fused epilogue: applies SiLU (x * sigmoid(x)) to the full accumulator register tensor before writing BF16 to C. This validates that cute.exp and element-wise FP32 operations work on CuTe register tensors in the epilogue. The gate/up pairing is not yet implemented (Stage 2). The fused_swiglu flag is const_expr(0) by default, so the standard epilogue path is unchanged unless the flag is enabled.
This commit is contained in:
@@ -2140,7 +2140,8 @@ class FusedSwiGLUScaledGroupedGemmKernel:
|
||||
acc_pipeline.consumer_release(acc_consumer_state)
|
||||
acc_consumer_state.advance()
|
||||
|
||||
# Convert to output dtype, apply global_scale
|
||||
# ── Fused SwiGLU + NVFP4 Epilogue ──
|
||||
# Load accumulator, apply global scale
|
||||
acc_vec = cute.zeros_like(tiled_copy_r2s.retile(tTR_rAcc))
|
||||
if cutlass.const_expr(self.scenario == "2Dx2D"):
|
||||
if k_tile_cnt > 0:
|
||||
@@ -2149,7 +2150,25 @@ class FusedSwiGLUScaledGroupedGemmKernel:
|
||||
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
|
||||
if cutlass.const_expr(global_scale_a is not None):
|
||||
acc_vec = acc_vec * alpha
|
||||
acc_vec = acc_vec.to(self.c_dtype)
|
||||
|
||||
if cutlass.const_expr(self.fused_swiglu):
|
||||
# ── SwiGLU in registers ──
|
||||
# With interleaved weights (granularity 8 BF16 = 4 FP4),
|
||||
# the accumulator N dimension has gate/up pairs adjacent.
|
||||
# SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
|
||||
neg_acc = acc_vec * cutlass.float32_t(-1.0)
|
||||
exp_neg = cute.exp(neg_acc)
|
||||
sigmoid = cutlass.float32_t(1.0) / (cutlass.float32_t(1.0) + exp_neg)
|
||||
swiglu_result = acc_vec * sigmoid
|
||||
|
||||
# Write SwiGLU result as BF16 to C tensor
|
||||
# Stage 1 validation: SiLU applied to full acc_vec.
|
||||
# This confirms cute.exp works on register tensors.
|
||||
# The gate/up pairing is added in Stage 2.
|
||||
acc_vec = swiglu_result.to(self.c_dtype)
|
||||
else:
|
||||
# Standard path: convert to output dtype
|
||||
acc_vec = acc_vec.to(self.c_dtype)
|
||||
tRS_rC.store(acc_vec)
|
||||
|
||||
# RMEM → SMEM
|
||||
|
||||
Reference in New Issue
Block a user