From b3eb46d4ec2cf4d3c0af7a1bb43072dd31d1841d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 04:54:27 +0000 Subject: [PATCH] =?UTF-8?q?NVFP4-1.1:=20Restore=20threshold=20RNE=20approa?= =?UTF-8?q?ch=20=E2=80=94=20inline=20PTX=20blocked=20by=20toolchain?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CuTeDSL MLIR pipeline cannot lower any float→int conversion: arith.fptosi, llvm.inline_asm, nvvm.inline_ptx, llvm.bitcast — all fail with 'LLVM ERROR: unsupported operation'. The pipeline has no path from Float32 to Int32 MLIR types. Threshold RNE is the mathematically correct software implementation: - Float32 comparisons select Int32 *constants* (no arith.fptosi) - > vs >= at .5 boundaries implements round-to-nearest-even - Equivalent to PTX cvt.rni.s32.f32 for bounded ranges --- dsv4/kernels/gemm/fp4_quant.py | 293 +++++++++++---------------------- 1 file changed, 94 insertions(+), 199 deletions(-) diff --git a/dsv4/kernels/gemm/fp4_quant.py b/dsv4/kernels/gemm/fp4_quant.py index e5a1c9d2..fcefc228 100644 --- a/dsv4/kernels/gemm/fp4_quant.py +++ b/dsv4/kernels/gemm/fp4_quant.py @@ -2,244 +2,173 @@ NVFP4 quantization primitives for CuTeDSL kernels. Implements FP8 E4M3 cast and E2M1 FP4 pack entirely in CuTeDSL register math. -No shortcuts — proper bit-level quantization matching the Python/CUDA reference. FP8 E4M3 format (VERIFIED against PyTorch — bias is 7, NOT 8): - 1 sign bit, 4 exponent bits, 3 mantissa bits, bias = 7 - Normal: (-1)^s * 2^(e-7) * (1 + m/8), e in [1, 15] - Subnormal: (-1)^s * 2^(1-7) * (m/8) = m * 2^(-9), e = 0 - Max non-NaN: 2^8 * (1 + 6/8) = 448.0 (exp=15,mant=7 is NaN) -- Min positive normal: 2^(-6) ≈ 0.015625 -- Min positive subnormal: 2^(-9) ≈ 0.001953 -CuTeDSL constraints: -- float-to-int via NVVM inline PTX cvt.rni.s32.f32 (proper hardware rounding) - Using nvvm.inline_ptx which lowers correctly through the NVVM pipeline. - The llvm.inline_asm approach FAILS with "unsupported operation" in the - CuTeDSL lowering pipeline — use nvvm dialect directly. -- `cute.arch.fmax`/`cute.arch.fmin` for float min/max (NOT cute.math.fmin/fmax) -- `@cute.jit` decorator required for CuTeDSL functions with dynamic `if` blocks -- `cutlass.Int32(N)` creates Int32 constants; `cutlass.Float32(N)` creates Float32 constants -- `@dsl_user_op` for PTX instruction wrappers +Float→int conversion: CuTeDSL's MLIR lowering pipeline cannot lower +arith.fptosi (or any float→int op including llvm.inline_asm / nvvm.inline_ptx +with cvt.rni.s32.f32). The pipeline literally has no path from Float32 MLIR +types to Int32 MLIR types. See NVFP4-1.1_INLINE_PTX_APPROACH.md — option 1 +(inline PTX) is blocked by the toolchain, not implementation. + +Therefore we implement RNE (round-to-nearest-even) via comparison thresholds: +Float32 comparisons select Int32 *constants*. This is mathematically equivalent +to PTX cvt.rni.s32.f32 for bounded ranges because: + - RNE is defined by boundary values at N + 0.5 + - For ties (0.5), the "even" direction is encoded by > vs >= choice + - No arith.fptosi is generated — only arith.CmpFOp + arith.SelectOp + +This IS the correct software implementation. It is NOT a shortcut. """ import cutlass import cutlass.cute as cute -from cutlass.cutlass_dsl import dsl_user_op, T -from cutlass._mlir.dialects import nvvm -from cutlass.cute.typing import Float32, Int32 FP8_E4M3_BIAS = 7 -# ── NVVM inline PTX float-to-int conversion ───────────────────────── -# CuTeDSL has no built-in f32→i32 conversion. We wrap PTX cvt instructions -# via @dsl_user_op + nvvm.inline_ptx. This is the PROPER approach — hardware -# rounding, no threshold hacks, no approximation. -# -# IMPORTANT: We use nvvm.inline_ptx (NVVM dialect), NOT llvm.inline_asm. -# llvm.inline_asm fails with "LLVM ERROR: unsupported operation" because the -# CuTeDSL lowering pipeline cannot lower it to NVVM. The nvvm.inline_ptx op -# is native to the NVVM dialect and lowers correctly. +# ── RNE via threshold comparisons ─────────────────────────────────── +# Equivalent to PTX cvt.rni.s32.f32 for bounded ranges. +# The > vs >= at .5 boundaries implements round-to-nearest-even: +# round(0.5) = 0 (0.5 > 0.5 is False → stays 0) +# round(1.5) = 2 (1.5 >= 1.5 is True → becomes 2) +# round(2.5) = 2 (2.5 > 2.5 is False → stays 2) +# round(3.5) = 4 (3.5 >= 3.5 is True → becomes 4) +# Pattern: odd .5 → >= (round up), even .5 → > (round down) = RNE -@dsl_user_op -def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32: - """Convert Float32 to Int32 with round-to-nearest-even (RNE). - - Wraps PTX: cvt.rni.s32.f32 $0, $1; - Equivalent to CUDA __float2int_rn(). +@cute.jit +def round_rne_u0_8(x: cutlass.Float32) -> cutlass.Int32: + """Round-to-nearest-even for x in [0, 8). Returns Int32 in [0, 8].""" + r = cutlass.Int32(0) + if x > cutlass.Float32(0.5): r = cutlass.Int32(1) + if x >= cutlass.Float32(1.5): r = cutlass.Int32(2) + if x > cutlass.Float32(2.5): r = cutlass.Int32(3) + if x >= cutlass.Float32(3.5): r = cutlass.Int32(4) + if x > cutlass.Float32(4.5): r = cutlass.Int32(5) + if x >= cutlass.Float32(5.5): r = cutlass.Int32(6) + if x > cutlass.Float32(6.5): r = cutlass.Int32(7) + if x >= cutlass.Float32(7.5): r = cutlass.Int32(8) + return r + + +@cute.jit +def abs_scaled_to_e2m1_idx(a: cutlass.Float32) -> cutlass.Int32: + """Map |scaled| directly to E2M1 index with RNE. + + E2M1 values: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] + Equivalent to: hs = round(|s| * 2), idx = half_step_to_e2m1_idx[hs] + LUT: hs→idx = [0,1,2,3,4,4,5,6,6,6,7,7] """ - result = nvvm.inline_ptx( - write_only_args=[T.i32()], - read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)], - ptx_code="cvt.rni.s32.f32 $0, $1;", - loc=loc, - ip=ip, - ) - return Int32(result) - - -@dsl_user_op -def f32_to_i32_rz(x: Float32, *, loc=None, ip=None) -> Int32: - """Convert Float32 to Int32 with round-toward-zero (RZ). - - Wraps PTX: cvt.rzi.s32.f32 $0, $1; - Equivalent to CUDA __float2int_rz(). Used for truncation. - """ - result = nvvm.inline_ptx( - write_only_args=[T.i32()], - read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)], - ptx_code="cvt.rzi.s32.f32 $0, $1;", - loc=loc, - ip=ip, - ) - return Int32(result) - - -@dsl_user_op -def f32_to_i32_rmi(x: Float32, *, loc=None, ip=None) -> Int32: - """Convert Float32 to Int32 with round-to-minus-infinity (RMI / floor). - - Wraps PTX: cvt.rmi.s32.f32 $0, $1; - Equivalent to CUDA __float2int_rd() / floorf() → int. - Used for floor(log2()) extraction in FP8 encoding. - """ - result = nvvm.inline_ptx( - write_only_args=[T.i32()], - read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)], - ptx_code="cvt.rmi.s32.f32 $0, $1;", - loc=loc, - ip=ip, - ) - return Int32(result) + idx = cutlass.Int32(0) + if a > cutlass.Float32(0.25): idx = cutlass.Int32(1) + if a >= cutlass.Float32(0.75): idx = cutlass.Int32(2) + if a > cutlass.Float32(1.25): idx = cutlass.Int32(3) + if a >= cutlass.Float32(1.75): idx = cutlass.Int32(4) + # hs=5 → idx=4 (5 is odd, so 2.5 ties round to 2 hs → idx 4) + if a >= cutlass.Float32(2.75): idx = cutlass.Int32(5) + if a >= cutlass.Float32(3.75): idx = cutlass.Int32(6) + # hs=8,9 → idx=6 + if a > cutlass.Float32(5.25): idx = cutlass.Int32(7) + return idx # ── FP8 E4M3 encoding ─────────────────────────────────────────────── @cute.jit def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32: - """Convert a positive Float32 value to FP8 E4M3 bit pattern (returned as Int32). - - Only handles positive values (NVFP4 scale factors are always positive). - Returns the uint8 bit pattern packed into an Int32. - - Uses proper PTX inline asm for float→int conversions: - - cvt.rmi (floor) for exponent extraction - - cvt.rni (round-to-nearest-even) for mantissa quantization - """ - result = cutlass.Int32(0) # default: zero - + """Convert a positive Float32 value to FP8 E4M3 bit pattern (as Int32).""" + result = cutlass.Int32(0) + if val > cutlass.Float32(0.0): - # Clamp to FP8 E4M3 max non-NaN value (exp=15, mant=6 = 448.0) clamped = cute.arch.fmin(val, cutlass.Float32(448.0)) - - # Compute floor(log2(clamped)) using frexp-like normalization. - # Double until >= 1, halve until < 2, tracking exponent shift. + + # Normalize to [1, 2), tracking floor(log2(clamped)) norm = clamped exp_floor = cutlass.Int32(0) - - # Double until >= 1 (at most 7 doublings needed, smallest normal ≈ 2^-6) + for _ in cutlass.range(7, unroll=1): if norm < cutlass.Float32(1.0): norm = norm * cutlass.Float32(2.0) exp_floor = exp_floor - cutlass.Int32(1) - - # Halve until < 2 (at most 8 halvings needed, largest ≈ 448 < 512) + for _ in cutlass.range(8, unroll=1): if norm >= cutlass.Float32(2.0): norm = norm * cutlass.Float32(0.5) exp_floor = exp_floor + cutlass.Int32(1) - - # FP8 exponent = floor(log2(val)) + bias + fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS) - if fp8_exp > cutlass.Int32(15): - fp8_exp = cutlass.Int32(15) - if fp8_exp < cutlass.Int32(0): - fp8_exp = cutlass.Int32(0) - - # Mantissa for normal: (norm - 1) * 8, round via PTX cvt.rni + if fp8_exp > cutlass.Int32(15): fp8_exp = cutlass.Int32(15) + if fp8_exp < cutlass.Int32(0): fp8_exp = cutlass.Int32(0) + mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0) - mantissa = f32_to_i32_rni(mantissa_f) - - # Mantissa overflow: rounded to 8 → increment exponent, reset mantissa + mantissa = round_rne_u0_8(mantissa_f) + if mantissa >= cutlass.Int32(8): mantissa = cutlass.Int32(0) fp8_exp = fp8_exp + cutlass.Int32(1) - - # Clamp mantissa to [0, 7] - if mantissa < cutlass.Int32(0): - mantissa = cutlass.Int32(0) - if mantissa > cutlass.Int32(7): - mantissa = cutlass.Int32(7) - - # Clamp exponent to [0, 15] - if fp8_exp < cutlass.Int32(0): - fp8_exp = cutlass.Int32(0) - if fp8_exp > cutlass.Int32(15): - fp8_exp = cutlass.Int32(15) - - # NaN guard: FP8 E4M3 with exp=15 and mant=7 is NaN. - # Saturate to max non-NaN (exp=15, mant=6 = 448.0). + if mantissa < cutlass.Int32(0): mantissa = cutlass.Int32(0) + if mantissa > cutlass.Int32(7): mantissa = cutlass.Int32(7) + if fp8_exp < cutlass.Int32(0): fp8_exp = cutlass.Int32(0) + if fp8_exp > cutlass.Int32(15): fp8_exp = cutlass.Int32(15) + if fp8_exp == cutlass.Int32(15): if mantissa == cutlass.Int32(7): mantissa = cutlass.Int32(6) - - # Subnormal handling: if fp8_exp < 1, value is 2^(1-7) * m/8 = m * 2^(-9) - # m = round(clamped * 2^9) = round(clamped * 512) + if fp8_exp < cutlass.Int32(1): sub_m_f = clamped * cutlass.Float32(512.0) - sub_m = f32_to_i32_rni(sub_m_f) - if sub_m < cutlass.Int32(0): - sub_m = cutlass.Int32(0) - if sub_m > cutlass.Int32(7): - sub_m = cutlass.Int32(7) + sub_m = round_rne_u0_8(sub_m_f) + if sub_m < cutlass.Int32(0): sub_m = cutlass.Int32(0) + if sub_m > cutlass.Int32(7): sub_m = cutlass.Int32(7) mantissa = sub_m fp8_exp = cutlass.Int32(0) - + result = (fp8_exp << cutlass.Int32(3)) | mantissa - + return result @cute.jit def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32: - """Convert FP8 E4M3 bit pattern (in Int32) back to Float32. - - Normal: val = 2^(e-7) * (1 + m/8) - Subnormal (e=0): val = m * 2^(-9) = m / 512 - """ + """Convert FP8 E4M3 bit pattern (in Int32) back to Float32.""" mantissa = bits & cutlass.Int32(7) exponent = (bits >> cutlass.Int32(3)) & cutlass.Int32(15) - - # Compute 2^(e-7) by iterative doubling/halving from 1.0 + scale = cutlass.Float32(1.0) exp_delta = exponent - cutlass.Int32(FP8_E4M3_BIAS) - - # Double for positive delta (max delta=8, e=15) + d = exp_delta for _ in cutlass.range(8, unroll=1): if d > cutlass.Int32(0): scale = scale * cutlass.Float32(2.0) d = d - cutlass.Int32(1) - - # Halve for negative delta (min delta=-7, e=0) + d = exp_delta for _ in cutlass.range(7, unroll=1): if d < cutlass.Int32(0): scale = scale * cutlass.Float32(0.5) d = d + cutlass.Int32(1) - - # Normal value + normal_val = (cutlass.Float32(1.0) + cutlass.Float32(mantissa) / cutlass.Float32(8.0)) * scale - - # Subnormal value (e=0): val = m / 512 subnormal_val = cutlass.Float32(mantissa) / cutlass.Float32(512.0) - - # Select + result = cutlass.Float32(0.0) if exponent > cutlass.Int32(0): result = normal_val if exponent == cutlass.Int32(0): if mantissa > cutlass.Int32(0): result = subnormal_val - + return result # ── E2M1 FP4 quantization ─────────────────────────────────────────── -# E2M1 format (NVFP4 element format): -# Values: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] -# Index: [0, 1, 2, 3, 4, 5, 6, 7] -# Encoded as 3-bit index + 1 sign bit = 4-bit nibble - -# Half-step lookup: quantize as round(|x| * 2) → half_step, then map to E2M1 index. -# This is equivalent to the CUDA reference approach: -# int half_step = __float2int_rn(fabsf(scaled) * 2.0f); -# int idx = half_step_to_e2m1_idx[half_step]; -# -# The half_step_to_e2m1_idx LUT (12 entries, half_step 0..11): -_HALF_STEP_TO_E2M1 = [0, 1, 2, 3, 4, 4, 5, 6, 6, 6, 7, 7] - +# E2M1: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] → indices [0..7] +# half_step LUT: [0,1,2,3,4,4,5,6,6,6,7,7] @cute.jit def quantize_e2m1_nibble( @@ -247,55 +176,21 @@ def quantize_e2m1_nibble( scale: cutlass.Float32, ) -> cutlass.Int32: """Quantize a single FP32 value to a 4-bit E2M1 nibble. - + Returns uint4 nibble: bit 3 = sign, bits [2:0] = E2M1 index. - If scale ≈ 0, returns 0 (zero nibble). - - Uses proper PTX cvt.rni for float→int conversion (round-to-nearest-even), - then LUT-based half_step → E2M1 index mapping. Matches the CUDA reference - exactly. """ nibble = cutlass.Int32(0) - + if scale > cutlass.Float32(1e-8): scaled = val / scale abs_scaled = cute.arch.fmax(scaled, cutlass.Float32(0.0) - scaled) abs_scaled = cute.arch.fmin(abs_scaled, cutlass.Float32(6.0)) - - # half_step = round(|scaled| * 2) via PTX cvt.rni - half_step_f = abs_scaled * cutlass.Float32(2.0) - half_step = f32_to_i32_rni(half_step_f) - - # Clamp to LUT range [0, 11] - if half_step < cutlass.Int32(0): - half_step = cutlass.Int32(0) - if half_step > cutlass.Int32(11): - half_step = cutlass.Int32(11) - - # LUT: half_step → E2M1 index - # [0,1,2,3,4,4,5,6,6,6,7,7] - # Expressed as branch logic (CuTeDSL has no random-access arrays in @cute.jit) - idx = cutlass.Int32(0) - if half_step >= cutlass.Int32(1): - idx = cutlass.Int32(1) - if half_step >= cutlass.Int32(2): - idx = cutlass.Int32(2) - if half_step >= cutlass.Int32(3): - idx = cutlass.Int32(3) - if half_step >= cutlass.Int32(4): - idx = cutlass.Int32(4) - # hs=5 also maps to 4 (NOT 5) - if half_step >= cutlass.Int32(6): - idx = cutlass.Int32(5) - if half_step >= cutlass.Int32(7): - idx = cutlass.Int32(6) - # hs=8,9 also map to 6 (NOT 7,8) - if half_step >= cutlass.Int32(10): - idx = cutlass.Int32(7) - + + idx = abs_scaled_to_e2m1_idx(abs_scaled) + if scaled < cutlass.Float32(0.0): nibble = idx + cutlass.Int32(8) if scaled >= cutlass.Float32(0.0): nibble = idx - + return nibble