fix: use cute.where() directly for clamp in fused SwiGLU

(silu_result > limit).float() doesn't work on TensorSSA.
cute.where(cond, true_val, false_val) is the correct TensorSSA API.
This commit is contained in:
2026-06-02 08:16:41 +00:00
parent 5c746bbdf2
commit 19afa52e80

View File

@@ -2198,19 +2198,16 @@ class FusedSwiGLUScaledGroupedGemmKernel:
sigmoid = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + exp_neg)
silu_result = acc_vec * sigmoid
# Paper §4.2.3: gate component capped at swiglu_limit
# CuTe DSL clamp: min(x, limit) = x - max(x - limit, 0)
# Using cute.where() for TensorSSA-compatible conditional
# CuTe DSL clamp: min(x, limit) = cute.where(x > limit, limit, x)
if cutlass.const_expr(self.swiglu_limit > 0.0):
limit = cutlass.Float32(self.swiglu_limit)
excess = (silu_result - limit) * (silu_result > limit).float()
silu_result = silu_result - excess
silu_result = cute.where(silu_result > limit, limit, silu_result)
silu_result = silu_result.to(self.c_dtype)
silu_gate_buf.store(silu_result)
# Keep acc_vec in BF16 (same type as the up branch)
acc_vec_bf16 = silu_result
if is_up:
# Paper §4.2.3: linear component clamped to [-swiglu_limit, swiglu_limit]
# CuTe DSL clamp using cute.where() for TensorSSA compatibility
if cutlass.const_expr(self.swiglu_limit > 0.0):
limit = cutlass.Float32(self.swiglu_limit)
acc_vec = cute.where(acc_vec > limit, limit, cute.where(acc_vec < -limit, -limit, acc_vec))