diff --git a/dsv4/kernels/gemm/fused_swiglu.py b/dsv4/kernels/gemm/fused_swiglu.py index 61fcffeb..5d610f18 100644 --- a/dsv4/kernels/gemm/fused_swiglu.py +++ b/dsv4/kernels/gemm/fused_swiglu.py @@ -2199,7 +2199,7 @@ class FusedSwiGLUScaledGroupedGemmKernel: silu_result = acc_vec * sigmoid # Paper §4.2.3: gate component capped at swiglu_limit if cutlass.const_expr(self.swiglu_limit > 0.0): - silu_result = cute.math.fmin(silu_result, cutlass.Float32(self.swiglu_limit)) + silu_result = cute.arch.fmin(silu_result, cutlass.Float32(self.swiglu_limit)) silu_result = silu_result.to(self.c_dtype) silu_gate_buf.store(silu_result) # Keep acc_vec in BF16 (same type as the up branch) @@ -2207,7 +2207,7 @@ class FusedSwiGLUScaledGroupedGemmKernel: if is_up: # Paper §4.2.3: linear component clamped to [-swiglu_limit, swiglu_limit] if cutlass.const_expr(self.swiglu_limit > 0.0): - acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, cutlass.Float32(-self.swiglu_limit)), cutlass.Float32(self.swiglu_limit)) + acc_vec = cute.arch.fmin(cute.arch.fmax(acc_vec, cutlass.Float32(-self.swiglu_limit)), cutlass.Float32(self.swiglu_limit)) # SwiGLU: silu(gate) * up gate_vals = silu_gate_buf.load() swiglu_result = (gate_vals * acc_vec.to(self.c_dtype))