fix: cute.math.fmin/fmax → cute.arch.fmin/fmax in fused SwiGLU kernel
`cute.math` has no `fmin`/`fmax`; `cute.arch` provides them as register-level ops. Per README constraint #4, use `cute.arch.fmax` inside a plain `range()` loop, not one with `vectorize=True`.
This commit is contained in:
@@ -2199,7 +2199,7 @@ class FusedSwiGLUScaledGroupedGemmKernel:
|
|||||||
silu_result = acc_vec * sigmoid
|
silu_result = acc_vec * sigmoid
|
||||||
# Paper §4.2.3: gate component capped at swiglu_limit
|
# Paper §4.2.3: gate component capped at swiglu_limit
|
||||||
if cutlass.const_expr(self.swiglu_limit > 0.0):
|
if cutlass.const_expr(self.swiglu_limit > 0.0):
|
||||||
silu_result = cute.math.fmin(silu_result, cutlass.Float32(self.swiglu_limit))
|
silu_result = cute.arch.fmin(silu_result, cutlass.Float32(self.swiglu_limit))
|
||||||
silu_result = silu_result.to(self.c_dtype)
|
silu_result = silu_result.to(self.c_dtype)
|
||||||
silu_gate_buf.store(silu_result)
|
silu_gate_buf.store(silu_result)
|
||||||
# Keep acc_vec in BF16 (same type as the up branch)
|
# Keep acc_vec in BF16 (same type as the up branch)
|
||||||
@@ -2207,7 +2207,7 @@ class FusedSwiGLUScaledGroupedGemmKernel:
|
|||||||
if is_up:
|
if is_up:
|
||||||
# Paper §4.2.3: linear component clamped to [-swiglu_limit, swiglu_limit]
|
# Paper §4.2.3: linear component clamped to [-swiglu_limit, swiglu_limit]
|
||||||
if cutlass.const_expr(self.swiglu_limit > 0.0):
|
if cutlass.const_expr(self.swiglu_limit > 0.0):
|
||||||
acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, cutlass.Float32(-self.swiglu_limit)), cutlass.Float32(self.swiglu_limit))
|
acc_vec = cute.arch.fmin(cute.arch.fmax(acc_vec, cutlass.Float32(-self.swiglu_limit)), cutlass.Float32(self.swiglu_limit))
|
||||||
# SwiGLU: silu(gate) * up
|
# SwiGLU: silu(gate) * up
|
||||||
gate_vals = silu_gate_buf.load()
|
gate_vals = silu_gate_buf.load()
|
||||||
swiglu_result = (gate_vals * acc_vec.to(self.c_dtype))
|
swiglu_result = (gate_vals * acc_vec.to(self.c_dtype))
|
||||||
|
|||||||
Reference in New Issue
Block a user