NVFP4-1.1: fix cute.math.fmax -> cute.arch.fmax (correct CuTeDSL API)

This commit is contained in:
2026-05-28 03:48:51 +00:00
parent 60790564f0
commit 6f94925491
2 changed files with 8 additions and 8 deletions

View File

@@ -98,7 +98,7 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
# FP8 exponent = floor(log2(val)) + bias
fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS)
fp8_exp = cute.math.fmin(fp8_exp, cutlass.Int32(15))
fp8_exp = cute.math.fmax(fp8_exp, cutlass.Int32(0))
fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0))
# Mantissa for normal: (norm - 1) * 8, round
mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0)
@@ -111,9 +111,9 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
fp8_exp = fp8_exp + cutlass.Int32(1)
mantissa = cute.math.fmin(mantissa, cutlass.Int32(7))
mantissa = cute.math.fmax(mantissa, cutlass.Int32(0))
mantissa = cute.arch.fmax(mantissa, cutlass.Int32(0))
fp8_exp = cute.math.fmin(fp8_exp, cutlass.Int32(15))
fp8_exp = cute.math.fmax(fp8_exp, cutlass.Int32(0))
fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0))
# NaN guard: FP8 E4M3 with exp=15 and mant=7 is NaN.
# Saturate to max non-NaN (exp=15, mant=6 = 448.0).
@@ -127,7 +127,7 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
sub_m_f = clamped * cutlass.Float32(512.0)
sub_m = cutlass.Int32(sub_m_f) # round-to-nearest-even
sub_m = cute.math.fmin(sub_m, cutlass.Int32(7))
sub_m = cute.math.fmax(sub_m, cutlass.Int32(1))
sub_m = cute.arch.fmax(sub_m, cutlass.Int32(1))
mantissa = sub_m
fp8_exp = cutlass.Int32(0)
@@ -193,13 +193,13 @@ def quantize_e2m1_nibble(
if scale > cutlass.Float32(1e-8):
scaled = val / scale
abs_scaled = cute.math.fmax(scaled, cutlass.Float32(0.0) - scaled)
abs_scaled = cute.arch.fmax(scaled, cutlass.Float32(0.0) - scaled)
abs_scaled = cute.math.fmin(abs_scaled, cutlass.Float32(6.0))
# half_step = round(|scaled| * 2) — round-to-nearest-even (matches __float2int_rn)
hs = cutlass.Int32(abs_scaled * cutlass.Float32(2.0))
hs = cute.math.fmin(hs, cutlass.Int32(12))
hs = cute.math.fmax(hs, cutlass.Int32(0))
hs = cute.arch.fmax(hs, cutlass.Int32(0))
idx = half_step_to_e2m1_idx(hs)

View File

@@ -48,8 +48,8 @@ def fp4_quant_test_kernel(
ptr = input_bf16.iterator + i * cutlass.Int32(2) # BF16=2 bytes
bf16_val = cute.arch.load(ptr, cutlass.BFloat16)
v = bf16_val.to(cutlass.Float32) / gs
a = cute.math.fmax(v, cutlass.Float32(0.0) - v)
amax = cute.math.fmax(amax, a)
a = cute.arch.fmax(v, cutlass.Float32(0.0) - v)
amax = cute.arch.fmax(amax, a)
# Block scale
bsf_f32 = amax / cutlass.Float32(6.0)