NVFP4-1.1: fix cute.math.fmax -> cute.arch.fmax (correct CuTeDSL API)
This commit is contained in:
@@ -98,7 +98,7 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
|
||||
# FP8 exponent = floor(log2(val)) + bias
|
||||
fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS)
|
||||
fp8_exp = cute.math.fmin(fp8_exp, cutlass.Int32(15))
|
||||
fp8_exp = cute.math.fmax(fp8_exp, cutlass.Int32(0))
|
||||
fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0))
|
||||
|
||||
# Mantissa for normal: (norm - 1) * 8, round
|
||||
mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0)
|
||||
@@ -111,9 +111,9 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
|
||||
fp8_exp = fp8_exp + cutlass.Int32(1)
|
||||
|
||||
mantissa = cute.math.fmin(mantissa, cutlass.Int32(7))
|
||||
mantissa = cute.math.fmax(mantissa, cutlass.Int32(0))
|
||||
mantissa = cute.arch.fmax(mantissa, cutlass.Int32(0))
|
||||
fp8_exp = cute.math.fmin(fp8_exp, cutlass.Int32(15))
|
||||
fp8_exp = cute.math.fmax(fp8_exp, cutlass.Int32(0))
|
||||
fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0))
|
||||
|
||||
# NaN guard: FP8 E4M3 with exp=15 and mant=7 is NaN.
|
||||
# Saturate to max non-NaN (exp=15, mant=6 = 448.0).
|
||||
@@ -127,7 +127,7 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
|
||||
sub_m_f = clamped * cutlass.Float32(512.0)
|
||||
sub_m = cutlass.Int32(sub_m_f) # round-to-nearest-even
|
||||
sub_m = cute.math.fmin(sub_m, cutlass.Int32(7))
|
||||
sub_m = cute.math.fmax(sub_m, cutlass.Int32(1))
|
||||
sub_m = cute.arch.fmax(sub_m, cutlass.Int32(1))
|
||||
mantissa = sub_m
|
||||
fp8_exp = cutlass.Int32(0)
|
||||
|
||||
@@ -193,13 +193,13 @@ def quantize_e2m1_nibble(
|
||||
|
||||
if scale > cutlass.Float32(1e-8):
|
||||
scaled = val / scale
|
||||
abs_scaled = cute.math.fmax(scaled, cutlass.Float32(0.0) - scaled)
|
||||
abs_scaled = cute.arch.fmax(scaled, cutlass.Float32(0.0) - scaled)
|
||||
abs_scaled = cute.math.fmin(abs_scaled, cutlass.Float32(6.0))
|
||||
|
||||
# half_step = round(|scaled| * 2) — round-to-nearest-even (matches __float2int_rn)
|
||||
hs = cutlass.Int32(abs_scaled * cutlass.Float32(2.0))
|
||||
hs = cute.math.fmin(hs, cutlass.Int32(12))
|
||||
hs = cute.math.fmax(hs, cutlass.Int32(0))
|
||||
hs = cute.arch.fmax(hs, cutlass.Int32(0))
|
||||
|
||||
idx = half_step_to_e2m1_idx(hs)
|
||||
|
||||
|
||||
@@ -48,8 +48,8 @@ def fp4_quant_test_kernel(
|
||||
ptr = input_bf16.iterator + i * cutlass.Int32(2) # BF16=2 bytes
|
||||
bf16_val = cute.arch.load(ptr, cutlass.BFloat16)
|
||||
v = bf16_val.to(cutlass.Float32) / gs
|
||||
a = cute.math.fmax(v, cutlass.Float32(0.0) - v)
|
||||
amax = cute.math.fmax(amax, a)
|
||||
a = cute.arch.fmax(v, cutlass.Float32(0.0) - v)
|
||||
amax = cute.arch.fmax(amax, a)
|
||||
|
||||
# Block scale
|
||||
bsf_f32 = amax / cutlass.Float32(6.0)
|
||||
|
||||
Reference in New Issue
Block a user