diff --git a/dsv4/kernels/gemm/fused_swiglu.py b/dsv4/kernels/gemm/fused_swiglu.py index bb0e467c..94a19b0d 100644 --- a/dsv4/kernels/gemm/fused_swiglu.py +++ b/dsv4/kernels/gemm/fused_swiglu.py @@ -2187,11 +2187,18 @@ class FusedSwiGLUScaledGroupedGemmKernel: neg_acc = acc_vec * cutlass.Float32(-1.0) exp_neg = cute.exp(neg_acc) sigmoid = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + exp_neg) - silu_result = (acc_vec * sigmoid).to(self.c_dtype) + silu_result = acc_vec * sigmoid + # Paper §4.2.3: gate component capped at swiglu_limit + if cutlass.const_expr(self.swiglu_limit > 0.0): + silu_result = cute.math.fmin(silu_result, cutlass.Float32(self.swiglu_limit)) + silu_result = silu_result.to(self.c_dtype) silu_gate_buf.store(silu_result) # Keep acc_vec in BF16 (same type as the up branch) acc_vec_bf16 = silu_result if is_up: + # Paper §4.2.3: linear component clamped to [-swiglu_limit, swiglu_limit] + if cutlass.const_expr(self.swiglu_limit > 0.0): + acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, cutlass.Float32(-self.swiglu_limit)), cutlass.Float32(self.swiglu_limit)) # SwiGLU: silu(gate) * up gate_vals = silu_gate_buf.load() swiglu_result = (gate_vals * acc_vec.to(self.c_dtype))