From 19afa52e807cb580438657ca452674e8eb475a70 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 2 Jun 2026 08:16:41 +0000 Subject: [PATCH] fix: use cute.where() directly for clamp in fused SwiGLU (silu_result > limit).float() doesn't work on TensorSSA. cute.where(cond, true_val, false_val) is the correct TensorSSA API. --- dsv4/kernels/gemm/fused_swiglu.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/gemm/fused_swiglu.py b/dsv4/kernels/gemm/fused_swiglu.py index 50ba91da..1c4b1d07 100644 --- a/dsv4/kernels/gemm/fused_swiglu.py +++ b/dsv4/kernels/gemm/fused_swiglu.py @@ -2198,19 +2198,16 @@ class FusedSwiGLUScaledGroupedGemmKernel: sigmoid = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + exp_neg) silu_result = acc_vec * sigmoid # Paper §4.2.3: gate component capped at swiglu_limit - # CuTe DSL clamp: min(x, limit) = x - max(x - limit, 0) - # Using cute.where() for TensorSSA-compatible conditional + # CuTe DSL clamp: min(x, limit) = cute.where(x > limit, limit, x) if cutlass.const_expr(self.swiglu_limit > 0.0): limit = cutlass.Float32(self.swiglu_limit) - excess = (silu_result - limit) * (silu_result > limit).float() - silu_result = silu_result - excess + silu_result = cute.where(silu_result > limit, limit, silu_result) silu_result = silu_result.to(self.c_dtype) silu_gate_buf.store(silu_result) # Keep acc_vec in BF16 (same type as the up branch) acc_vec_bf16 = silu_result if is_up: # Paper §4.2.3: linear component clamped to [-swiglu_limit, swiglu_limit] - # CuTe DSL clamp using cute.where() for TensorSSA compatibility if cutlass.const_expr(self.swiglu_limit > 0.0): limit = cutlass.Float32(self.swiglu_limit) acc_vec = cute.where(acc_vec > limit, limit, cute.where(acc_vec < -limit, -limit, acc_vec))