From 578d186c20d582a839616ee00549ccb62cb7c3eb Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 06:32:54 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20add=20SwiGLU=20clamping=20to=20fused=20k?= =?UTF-8?q?ernel=20(paper=20=C2=A74.2.3,=20CG-1)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The fused SwiGLU kernel stored swiglu_limit but never applied it. Paper §4.2.3: gate capped at swiglu_limit, linear clamped to [-limit, +limit]. Non-fused reference path already applies clamping correctly. Fix: add fmin/fmax clamping in FP32 before BF16 conversion. --- dsv4/kernels/gemm/fused_swiglu.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dsv4/kernels/gemm/fused_swiglu.py b/dsv4/kernels/gemm/fused_swiglu.py index bb0e467c..94a19b0d 100644 --- a/dsv4/kernels/gemm/fused_swiglu.py +++ b/dsv4/kernels/gemm/fused_swiglu.py @@ -2187,11 +2187,18 @@ class FusedSwiGLUScaledGroupedGemmKernel: neg_acc = acc_vec * cutlass.Float32(-1.0) exp_neg = cute.exp(neg_acc) sigmoid = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + exp_neg) - silu_result = (acc_vec * sigmoid).to(self.c_dtype) + silu_result = acc_vec * sigmoid + # Paper §4.2.3: gate component capped at swiglu_limit + if cutlass.const_expr(self.swiglu_limit > 0.0): + silu_result = cute.math.fmin(silu_result, cutlass.Float32(self.swiglu_limit)) + silu_result = silu_result.to(self.c_dtype) silu_gate_buf.store(silu_result) # Keep acc_vec in BF16 (same type as the up branch) acc_vec_bf16 = silu_result if is_up: + # Paper §4.2.3: linear component clamped to [-swiglu_limit, swiglu_limit] + if cutlass.const_expr(self.swiglu_limit > 0.0): + acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, cutlass.Float32(-self.swiglu_limit)), cutlass.Float32(self.swiglu_limit)) # SwiGLU: silu(gate) * up gate_vals = silu_gate_buf.load() swiglu_result = (gate_vals * acc_vec.to(self.c_dtype))