fix: add SwiGLU clamping to fused kernel (paper §4.2.3, CG-1)

The fused SwiGLU kernel stored swiglu_limit but never applied it.
Paper §4.2.3: gate capped at swiglu_limit, linear clamped to [-limit, +limit].
Non-fused reference path already applies clamping correctly.
Fix: add fmin/fmax clamping in FP32 before BF16 conversion.
This commit is contained in:
2026-05-23 06:32:54 +00:00
parent 11c7e2c663
commit 578d186c20

View File

@@ -2187,11 +2187,18 @@ class FusedSwiGLUScaledGroupedGemmKernel:
neg_acc = acc_vec * cutlass.Float32(-1.0)
exp_neg = cute.exp(neg_acc)
sigmoid = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + exp_neg)
silu_result = (acc_vec * sigmoid).to(self.c_dtype)
silu_result = acc_vec * sigmoid
# Paper §4.2.3: gate component capped at swiglu_limit
if cutlass.const_expr(self.swiglu_limit > 0.0):
silu_result = cute.math.fmin(silu_result, cutlass.Float32(self.swiglu_limit))
silu_result = silu_result.to(self.c_dtype)
silu_gate_buf.store(silu_result)
# Keep acc_vec in BF16 (same type as the up branch)
acc_vec_bf16 = silu_result
if is_up:
# Paper §4.2.3: linear component clamped to [-swiglu_limit, swiglu_limit]
if cutlass.const_expr(self.swiglu_limit > 0.0):
acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, cutlass.Float32(-self.swiglu_limit)), cutlass.Float32(self.swiglu_limit))
# SwiGLU: silu(gate) * up
gate_vals = silu_gate_buf.load()
swiglu_result = (gate_vals * acc_vec.to(self.c_dtype))