From 5c746bbdf2fb3d948d025ec1a98d60468f026f26 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 2 Jun 2026 08:15:46 +0000 Subject: [PATCH] fix: TensorSSA-compatible clamp in fused SwiGLU kernel cute.arch.fmin/fmax take scalar Float32, not TensorSSA. Replace with cute.where() and arithmetic for TensorSSA compatibility. Also changed subtile loop to unroll=1 for cute.where() compatibility. --- dsv4/kernels/gemm/fused_swiglu.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/dsv4/kernels/gemm/fused_swiglu.py b/dsv4/kernels/gemm/fused_swiglu.py index 5d610f18..50ba91da 100644 --- a/dsv4/kernels/gemm/fused_swiglu.py +++ b/dsv4/kernels/gemm/fused_swiglu.py @@ -2137,7 +2137,7 @@ class FusedSwiGLUScaledGroupedGemmKernel: if cutlass.const_expr(self.fused_swiglu): silu_gate_buf = cute.make_rmem_tensor(tiled_copy_r2s.retile(tTR_rAcc).shape, self.c_dtype) - for subtile_idx in cutlass.range(subtile_cnt): + for subtile_idx in cutlass.range(subtile_cnt, unroll=1): # unroll=1: SwiGLU + clamp needs cute.arch.fmin/fmax (impure for vectorizer) real_subtile_idx = subtile_idx if cutlass.const_expr(self.overlapping_accum): if reverse_subtile: @@ -2198,16 +2198,22 @@ class FusedSwiGLUScaledGroupedGemmKernel: sigmoid = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + exp_neg) silu_result = acc_vec * sigmoid # Paper §4.2.3: gate component capped at swiglu_limit + # CuTe DSL clamp: min(x, limit) = x - max(x - limit, 0) + # Using cute.where() for TensorSSA-compatible conditional if cutlass.const_expr(self.swiglu_limit > 0.0): - silu_result = cute.arch.fmin(silu_result, cutlass.Float32(self.swiglu_limit)) + limit = cutlass.Float32(self.swiglu_limit) + excess = (silu_result - limit) * (silu_result > limit).float() + silu_result = silu_result - excess silu_result = silu_result.to(self.c_dtype) silu_gate_buf.store(silu_result) # Keep acc_vec in BF16 (same type as the up branch) acc_vec_bf16 = silu_result if is_up: # Paper §4.2.3: linear component clamped to [-swiglu_limit, swiglu_limit] + # CuTe DSL clamp using cute.where() for TensorSSA compatibility if cutlass.const_expr(self.swiglu_limit > 0.0): - acc_vec = cute.arch.fmin(cute.arch.fmax(acc_vec, cutlass.Float32(-self.swiglu_limit)), cutlass.Float32(self.swiglu_limit)) + limit = cutlass.Float32(self.swiglu_limit) + acc_vec = cute.where(acc_vec > limit, limit, cute.where(acc_vec < -limit, -limit, acc_vec)) # SwiGLU: silu(gate) * up gate_vals = silu_gate_buf.load() swiglu_result = (gate_vals * acc_vec.to(self.c_dtype))