From 6f94925491adb91c92d9c8a7c01d1b0df2a0776f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 03:48:51 +0000 Subject: [PATCH] NVFP4-1.1: fix cute.math.fmax -> cute.arch.fmax (correct CuTeDSL API) --- dsv4/kernels/gemm/fp4_quant.py | 12 ++++++------ tests/unit/test_nvfp4_1_1_quant.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/dsv4/kernels/gemm/fp4_quant.py b/dsv4/kernels/gemm/fp4_quant.py index 796584b2..59dd2d14 100644 --- a/dsv4/kernels/gemm/fp4_quant.py +++ b/dsv4/kernels/gemm/fp4_quant.py @@ -98,7 +98,7 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32: # FP8 exponent = floor(log2(val)) + bias fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS) fp8_exp = cute.math.fmin(fp8_exp, cutlass.Int32(15)) - fp8_exp = cute.math.fmax(fp8_exp, cutlass.Int32(0)) + fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0)) # Mantissa for normal: (norm - 1) * 8, round mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0) @@ -111,9 +111,9 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32: fp8_exp = fp8_exp + cutlass.Int32(1) mantissa = cute.math.fmin(mantissa, cutlass.Int32(7)) - mantissa = cute.math.fmax(mantissa, cutlass.Int32(0)) + mantissa = cute.arch.fmax(mantissa, cutlass.Int32(0)) fp8_exp = cute.math.fmin(fp8_exp, cutlass.Int32(15)) - fp8_exp = cute.math.fmax(fp8_exp, cutlass.Int32(0)) + fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0)) # NaN guard: FP8 E4M3 with exp=15 and mant=7 is NaN. # Saturate to max non-NaN (exp=15, mant=6 = 448.0). @@ -127,7 +127,7 @@ def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32: sub_m_f = clamped * cutlass.Float32(512.0) sub_m = cutlass.Int32(sub_m_f) # round-to-nearest-even sub_m = cute.math.fmin(sub_m, cutlass.Int32(7)) - sub_m = cute.math.fmax(sub_m, cutlass.Int32(1)) + sub_m = cute.arch.fmax(sub_m, cutlass.Int32(1)) mantissa = sub_m fp8_exp = cutlass.Int32(0) @@ -193,13 +193,13 @@ def quantize_e2m1_nibble( if scale > cutlass.Float32(1e-8): scaled = val / scale - abs_scaled = cute.math.fmax(scaled, cutlass.Float32(0.0) - scaled) + abs_scaled = cute.arch.fmax(scaled, cutlass.Float32(0.0) - scaled) abs_scaled = cute.math.fmin(abs_scaled, cutlass.Float32(6.0)) # half_step = round(|scaled| * 2) — round-to-nearest-even (matches __float2int_rn) hs = cutlass.Int32(abs_scaled * cutlass.Float32(2.0)) hs = cute.math.fmin(hs, cutlass.Int32(12)) - hs = cute.math.fmax(hs, cutlass.Int32(0)) + hs = cute.arch.fmax(hs, cutlass.Int32(0)) idx = half_step_to_e2m1_idx(hs) diff --git a/tests/unit/test_nvfp4_1_1_quant.py b/tests/unit/test_nvfp4_1_1_quant.py index f2882760..6ad29469 100644 --- a/tests/unit/test_nvfp4_1_1_quant.py +++ b/tests/unit/test_nvfp4_1_1_quant.py @@ -48,8 +48,8 @@ def fp4_quant_test_kernel( ptr = input_bf16.iterator + i * cutlass.Int32(2) # BF16=2 bytes bf16_val = cute.arch.load(ptr, cutlass.BFloat16) v = bf16_val.to(cutlass.Float32) / gs - a = cute.math.fmax(v, cutlass.Float32(0.0) - v) - amax = cute.math.fmax(amax, a) + a = cute.arch.fmax(v, cutlass.Float32(0.0) - v) + amax = cute.arch.fmax(amax, a) # Block scale bsf_f32 = amax / cutlass.Float32(6.0)