NVFP4-1.1: fix Int32 clamping — use comparisons instead of fmin/fmax (float-only ops)
This commit is contained in:
@@ -135,8 +135,10 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
|
||||
|
||||
# FP8 exponent = floor(log2(val)) + bias
|
||||
fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS)
|
||||
fp8_exp = cute.arch.fmin(fp8_exp, cutlass.Int32(15))
|
||||
fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0))
|
||||
if fp8_exp > cutlass.Int32(15):
|
||||
fp8_exp = cutlass.Int32(15)
|
||||
if fp8_exp < cutlass.Int32(0):
|
||||
fp8_exp = cutlass.Int32(0)
|
||||
|
||||
# Mantissa for normal: (norm - 1) * 8, round via threshold
|
||||
mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0)
|
||||
@@ -147,10 +149,17 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
|
||||
mantissa = cutlass.Int32(0)
|
||||
fp8_exp = fp8_exp + cutlass.Int32(1)
|
||||
|
||||
mantissa = cute.arch.fmin(mantissa, cutlass.Int32(7))
|
||||
mantissa = cute.arch.fmax(mantissa, cutlass.Int32(0))
|
||||
fp8_exp = cute.arch.fmin(fp8_exp, cutlass.Int32(15))
|
||||
fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0))
|
||||
# Clamp mantissa to [0, 7]
|
||||
if mantissa < cutlass.Int32(0):
|
||||
mantissa = cutlass.Int32(0)
|
||||
if mantissa > cutlass.Int32(7):
|
||||
mantissa = cutlass.Int32(7)
|
||||
|
||||
# Clamp exponent to [0, 15]
|
||||
if fp8_exp < cutlass.Int32(0):
|
||||
fp8_exp = cutlass.Int32(0)
|
||||
if fp8_exp > cutlass.Int32(15):
|
||||
fp8_exp = cutlass.Int32(15)
|
||||
|
||||
# NaN guard: FP8 E4M3 with exp=15 and mant=7 is NaN.
|
||||
# Saturate to the largest finite value instead (exp=15, mant=6, i.e. 448.0).
|
||||
@@ -163,8 +172,10 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
|
||||
if fp8_exp < cutlass.Int32(1):
|
||||
sub_m_f = clamped * cutlass.Float32(512.0)
|
||||
sub_m = round_rne_u0_8(sub_m_f)
|
||||
sub_m = cute.arch.fmin(sub_m, cutlass.Int32(7))
|
||||
sub_m = cute.arch.fmax(sub_m, cutlass.Int32(0))
|
||||
if sub_m < cutlass.Int32(0):
|
||||
sub_m = cutlass.Int32(0)
|
||||
if sub_m > cutlass.Int32(7):
|
||||
sub_m = cutlass.Int32(7)
|
||||
mantissa = sub_m
|
||||
fp8_exp = cutlass.Int32(0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user