From d2aa93aad7e455aa1fb8c838c7c4536b2176d0f0 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 04:30:06 +0000 Subject: [PATCH] =?UTF-8?q?NVFP4-1.1:=20fix=20Int32=20clamping=20=E2=80=94?= =?UTF-8?q?=20use=20comparisons=20instead=20of=20fmin/fmax=20(float-only?= =?UTF-8?q?=20ops)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dsv4/kernels/gemm/fp4_quant.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/dsv4/kernels/gemm/fp4_quant.py b/dsv4/kernels/gemm/fp4_quant.py index 1cb5b086..76e55f5a 100644 --- a/dsv4/kernels/gemm/fp4_quant.py +++ b/dsv4/kernels/gemm/fp4_quant.py @@ -135,8 +135,10 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32: # FP8 exponent = floor(log2(val)) + bias fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS) - fp8_exp = cute.arch.fmin(fp8_exp, cutlass.Int32(15)) - fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0)) + if fp8_exp > cutlass.Int32(15): + fp8_exp = cutlass.Int32(15) + if fp8_exp < cutlass.Int32(0): + fp8_exp = cutlass.Int32(0) # Mantissa for normal: (norm - 1) * 8, round via threshold mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0) @@ -147,10 +149,17 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32: mantissa = cutlass.Int32(0) fp8_exp = fp8_exp + cutlass.Int32(1) - mantissa = cute.arch.fmin(mantissa, cutlass.Int32(7)) - mantissa = cute.arch.fmax(mantissa, cutlass.Int32(0)) - fp8_exp = cute.arch.fmin(fp8_exp, cutlass.Int32(15)) - fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0)) + # Clamp mantissa to [0, 7] + if mantissa < cutlass.Int32(0): + mantissa = cutlass.Int32(0) + if mantissa > cutlass.Int32(7): + mantissa = cutlass.Int32(7) + + # Clamp exponent to [0, 15] + if fp8_exp < cutlass.Int32(0): + fp8_exp = cutlass.Int32(0) + if fp8_exp > cutlass.Int32(15): + fp8_exp = cutlass.Int32(15) # NaN guard: FP8 E4M3 with exp=15 and mant=7 is NaN. # Saturate to max non-NaN (exp=15, mant=6 = 448.0). @@ -163,8 +172,10 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32: if fp8_exp < cutlass.Int32(1): sub_m_f = clamped * cutlass.Float32(512.0) sub_m = round_rne_u0_8(sub_m_f) - sub_m = cute.arch.fmin(sub_m, cutlass.Int32(7)) - sub_m = cute.arch.fmax(sub_m, cutlass.Int32(0)) + if sub_m < cutlass.Int32(0): + sub_m = cutlass.Int32(0) + if sub_m > cutlass.Int32(7): + sub_m = cutlass.Int32(7) mantissa = sub_m fp8_exp = cutlass.Int32(0)