diff --git a/dsv4/kernels/gemm/fp4_quant.py b/dsv4/kernels/gemm/fp4_quant.py index 1cb5b086..76e55f5a 100644 --- a/dsv4/kernels/gemm/fp4_quant.py +++ b/dsv4/kernels/gemm/fp4_quant.py @@ -135,8 +135,10 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32: # FP8 exponent = floor(log2(val)) + bias fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS) - fp8_exp = cute.arch.fmin(fp8_exp, cutlass.Int32(15)) - fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0)) + if fp8_exp > cutlass.Int32(15): + fp8_exp = cutlass.Int32(15) + if fp8_exp < cutlass.Int32(0): + fp8_exp = cutlass.Int32(0) # Mantissa for normal: (norm - 1) * 8, round via threshold mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0) @@ -147,10 +149,17 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32: mantissa = cutlass.Int32(0) fp8_exp = fp8_exp + cutlass.Int32(1) - mantissa = cute.arch.fmin(mantissa, cutlass.Int32(7)) - mantissa = cute.arch.fmax(mantissa, cutlass.Int32(0)) - fp8_exp = cute.arch.fmin(fp8_exp, cutlass.Int32(15)) - fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0)) + # Clamp mantissa to [0, 7] + if mantissa < cutlass.Int32(0): + mantissa = cutlass.Int32(0) + if mantissa > cutlass.Int32(7): + mantissa = cutlass.Int32(7) + + # Clamp exponent to [0, 15] + if fp8_exp < cutlass.Int32(0): + fp8_exp = cutlass.Int32(0) + if fp8_exp > cutlass.Int32(15): + fp8_exp = cutlass.Int32(15) # NaN guard: FP8 E4M3 with exp=15 and mant=7 is NaN. # Saturate to max non-NaN (exp=15, mant=6 = 448.0). @@ -163,8 +172,10 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32: if fp8_exp < cutlass.Int32(1): sub_m_f = clamped * cutlass.Float32(512.0) sub_m = round_rne_u0_8(sub_m_f) - sub_m = cute.arch.fmin(sub_m, cutlass.Int32(7)) - sub_m = cute.arch.fmax(sub_m, cutlass.Int32(0)) + if sub_m < cutlass.Int32(0): + sub_m = cutlass.Int32(0) + if sub_m > cutlass.Int32(7): + sub_m = cutlass.Int32(7) mantissa = sub_m fp8_exp = cutlass.Int32(0)