NVFP4-1.1: fix Int32 clamping — use comparisons instead of fmin/fmax (float-only ops)

This commit is contained in:
2026-05-28 04:30:06 +00:00
parent accc66741d
commit d2aa93aad7

View File

@@ -135,8 +135,10 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
# FP8 exponent = floor(log2(val)) + bias
fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS)
fp8_exp = cute.arch.fmin(fp8_exp, cutlass.Int32(15))
fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0))
if fp8_exp > cutlass.Int32(15):
fp8_exp = cutlass.Int32(15)
if fp8_exp < cutlass.Int32(0):
fp8_exp = cutlass.Int32(0)
# Mantissa for normal: (norm - 1) * 8, round via threshold
mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0)
@@ -147,10 +149,17 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
mantissa = cutlass.Int32(0)
fp8_exp = fp8_exp + cutlass.Int32(1)
mantissa = cute.arch.fmin(mantissa, cutlass.Int32(7))
mantissa = cute.arch.fmax(mantissa, cutlass.Int32(0))
fp8_exp = cute.arch.fmin(fp8_exp, cutlass.Int32(15))
fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0))
# Clamp mantissa to [0, 7]
if mantissa < cutlass.Int32(0):
mantissa = cutlass.Int32(0)
if mantissa > cutlass.Int32(7):
mantissa = cutlass.Int32(7)
# Clamp exponent to [0, 15]
if fp8_exp < cutlass.Int32(0):
fp8_exp = cutlass.Int32(0)
if fp8_exp > cutlass.Int32(15):
fp8_exp = cutlass.Int32(15)
# NaN guard: FP8 E4M3 with exp=15 and mant=7 is NaN.
# Saturate to max non-NaN (exp=15, mant=6 = 448.0).
@@ -163,8 +172,10 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
if fp8_exp < cutlass.Int32(1):
sub_m_f = clamped * cutlass.Float32(512.0)
sub_m = round_rne_u0_8(sub_m_f)
sub_m = cute.arch.fmin(sub_m, cutlass.Int32(7))
sub_m = cute.arch.fmax(sub_m, cutlass.Int32(0))
if sub_m < cutlass.Int32(0):
sub_m = cutlass.Int32(0)
if sub_m > cutlass.Int32(7):
sub_m = cutlass.Int32(7)
mantissa = sub_m
fp8_exp = cutlass.Int32(0)