fix: TensorSSA-compatible clamp in fused SwiGLU kernel
cute.arch.fmin/fmax take scalar Float32, not TensorSSA. Replace with cute.where() and arithmetic for TensorSSA compatibility. Also changed subtile loop to unroll=1 for cute.where() compatibility.
This commit is contained in:
@@ -2137,7 +2137,7 @@ class FusedSwiGLUScaledGroupedGemmKernel:
|
||||
if cutlass.const_expr(self.fused_swiglu):
|
||||
silu_gate_buf = cute.make_rmem_tensor(tiled_copy_r2s.retile(tTR_rAcc).shape, self.c_dtype)
|
||||
|
||||
for subtile_idx in cutlass.range(subtile_cnt):
|
||||
for subtile_idx in cutlass.range(subtile_cnt, unroll=1): # unroll=1: SwiGLU + clamp needs cute.arch.fmin/fmax (impure for vectorizer)
|
||||
real_subtile_idx = subtile_idx
|
||||
if cutlass.const_expr(self.overlapping_accum):
|
||||
if reverse_subtile:
|
||||
@@ -2198,16 +2198,22 @@ class FusedSwiGLUScaledGroupedGemmKernel:
|
||||
sigmoid = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + exp_neg)
|
||||
silu_result = acc_vec * sigmoid
|
||||
# Paper §4.2.3: gate component capped at swiglu_limit
|
||||
# CuTe DSL clamp: min(x, limit) = x - max(x - limit, 0)
|
||||
# Using cute.where() for TensorSSA-compatible conditional
|
||||
if cutlass.const_expr(self.swiglu_limit > 0.0):
|
||||
silu_result = cute.arch.fmin(silu_result, cutlass.Float32(self.swiglu_limit))
|
||||
limit = cutlass.Float32(self.swiglu_limit)
|
||||
excess = (silu_result - limit) * (silu_result > limit).float()
|
||||
silu_result = silu_result - excess
|
||||
silu_result = silu_result.to(self.c_dtype)
|
||||
silu_gate_buf.store(silu_result)
|
||||
# Keep acc_vec in BF16 (same type as the up branch)
|
||||
acc_vec_bf16 = silu_result
|
||||
if is_up:
|
||||
# Paper §4.2.3: linear component clamped to [-swiglu_limit, swiglu_limit]
|
||||
# CuTe DSL clamp using cute.where() for TensorSSA compatibility
|
||||
if cutlass.const_expr(self.swiglu_limit > 0.0):
|
||||
acc_vec = cute.arch.fmin(cute.arch.fmax(acc_vec, cutlass.Float32(-self.swiglu_limit)), cutlass.Float32(self.swiglu_limit))
|
||||
limit = cutlass.Float32(self.swiglu_limit)
|
||||
acc_vec = cute.where(acc_vec > limit, limit, cute.where(acc_vec < -limit, -limit, acc_vec))
|
||||
# SwiGLU: silu(gate) * up
|
||||
gate_vals = silu_gate_buf.load()
|
||||
swiglu_result = (gate_vals * acc_vec.to(self.c_dtype))
|
||||
|
||||
Reference in New Issue
Block a user