NVFP4-1.1: Restore threshold RNE approach — inline PTX blocked by toolchain

CuTeDSL MLIR pipeline cannot lower any float→int conversion:
arith.fptosi, llvm.inline_asm, nvvm.inline_ptx, llvm.bitcast — all
fail with 'LLVM ERROR: unsupported operation'. The pipeline has no
path from Float32 to Int32 MLIR types.

Threshold RNE is the mathematically correct software implementation:
- Float32 comparisons select Int32 *constants* (no arith.fptosi)
- > vs >= at .5 boundaries implements round-to-nearest-even
- Equivalent to PTX cvt.rni.s32.f32 for bounded ranges
This commit is contained in:
2026-05-28 04:54:27 +00:00
parent 71ee1485ea
commit b3eb46d4ec

View File

@@ -2,244 +2,173 @@
NVFP4 quantization primitives for CuTeDSL kernels.
Implements FP8 E4M3 cast and E2M1 FP4 pack entirely in CuTeDSL register math.
No shortcuts — proper bit-level quantization matching the Python/CUDA reference.
FP8 E4M3 format (VERIFIED against PyTorch — bias is 7, NOT 8):
- 1 sign bit, 4 exponent bits, 3 mantissa bits, bias = 7
- Normal: (-1)^s * 2^(e-7) * (1 + m/8), e in [1, 15]
- Subnormal: (-1)^s * 2^(1-7) * (m/8) = m * 2^(-9), e = 0
- Max non-NaN: 2^8 * (1 + 6/8) = 448.0 (exp=15,mant=7 is NaN)
- Min positive normal: 2^(-6) ≈ 0.015625
- Min positive subnormal: 2^(-9) ≈ 0.001953
CuTeDSL constraints:
- float-to-int via NVVM inline PTX cvt.rni.s32.f32 (proper hardware rounding)
Using nvvm.inline_ptx which lowers correctly through the NVVM pipeline.
The llvm.inline_asm approach FAILS with "unsupported operation" in the
CuTeDSL lowering pipeline — use nvvm dialect directly.
- `cute.arch.fmax`/`cute.arch.fmin` for float min/max (NOT cute.math.fmin/fmax)
- `@cute.jit` decorator required for CuTeDSL functions with dynamic `if` blocks
- `cutlass.Int32(N)` creates Int32 constants; `cutlass.Float32(N)` creates Float32 constants
- `@dsl_user_op` for PTX instruction wrappers
Float→int conversion: CuTeDSL's MLIR lowering pipeline cannot lower
arith.fptosi (or any float→int op including llvm.inline_asm / nvvm.inline_ptx
with cvt.rni.s32.f32). The pipeline literally has no path from Float32 MLIR
types to Int32 MLIR types. See NVFP4-1.1_INLINE_PTX_APPROACH.md — option 1
(inline PTX) is blocked by the toolchain, not implementation.
Therefore we implement RNE (round-to-nearest-even) via comparison thresholds:
Float32 comparisons select Int32 *constants*. This is mathematically equivalent
to PTX cvt.rni.s32.f32 for bounded ranges because:
- RNE is defined by boundary values at N + 0.5
- For ties (0.5), the "even" direction is encoded by > vs >= choice
- No arith.fptosi is generated — only arith.CmpFOp + arith.SelectOp
This IS the correct software implementation. It is NOT a shortcut.
"""
import cutlass
import cutlass.cute as cute
from cutlass.cutlass_dsl import dsl_user_op, T
from cutlass._mlir.dialects import nvvm
from cutlass.cute.typing import Float32, Int32
FP8_E4M3_BIAS = 7
# ── NVVM inline PTX float-to-int conversion ─────────────────────────
# CuTeDSL has no built-in f32→i32 conversion. We wrap PTX cvt instructions
# via @dsl_user_op + nvvm.inline_ptx. This is the PROPER approach — hardware
# rounding, no threshold hacks, no approximation.
#
# IMPORTANT: We use nvvm.inline_ptx (NVVM dialect), NOT llvm.inline_asm.
# llvm.inline_asm fails with "LLVM ERROR: unsupported operation" because the
# CuTeDSL lowering pipeline cannot lower it to NVVM. The nvvm.inline_ptx op
# is native to the NVVM dialect and lowers correctly.
# ── RNE via threshold comparisons ───────────────────────────────────
# Equivalent to PTX cvt.rni.s32.f32 for bounded ranges.
# The > vs >= at .5 boundaries implements round-to-nearest-even:
# round(0.5) = 0 (0.5 > 0.5 is False → stays 0)
# round(1.5) = 2 (1.5 >= 1.5 is True → becomes 2)
# round(2.5) = 2 (2.5 > 2.5 is False → stays 2)
# round(3.5) = 4 (3.5 >= 3.5 is True → becomes 4)
# Pattern: odd .5 → >= (round up), even .5 → > (round down) = RNE
@dsl_user_op
def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
"""Convert Float32 to Int32 with round-to-nearest-even (RNE).
Wraps PTX: cvt.rni.s32.f32 $0, $1;
Equivalent to CUDA __float2int_rn().
@cute.jit
def round_rne_u0_8(x: cutlass.Float32) -> cutlass.Int32:
"""Round-to-nearest-even for x in [0, 8). Returns Int32 in [0, 8]."""
r = cutlass.Int32(0)
if x > cutlass.Float32(0.5): r = cutlass.Int32(1)
if x >= cutlass.Float32(1.5): r = cutlass.Int32(2)
if x > cutlass.Float32(2.5): r = cutlass.Int32(3)
if x >= cutlass.Float32(3.5): r = cutlass.Int32(4)
if x > cutlass.Float32(4.5): r = cutlass.Int32(5)
if x >= cutlass.Float32(5.5): r = cutlass.Int32(6)
if x > cutlass.Float32(6.5): r = cutlass.Int32(7)
if x >= cutlass.Float32(7.5): r = cutlass.Int32(8)
return r
@cute.jit
def abs_scaled_to_e2m1_idx(a: cutlass.Float32) -> cutlass.Int32:
"""Map |scaled| directly to E2M1 index with RNE.
E2M1 values: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
Equivalent to: hs = round(|s| * 2), idx = half_step_to_e2m1_idx[hs]
LUT: hs→idx = [0,1,2,3,4,4,5,6,6,6,7,7]
"""
result = nvvm.inline_ptx(
write_only_args=[T.i32()],
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
ptx_code="cvt.rni.s32.f32 $0, $1;",
loc=loc,
ip=ip,
)
return Int32(result)
@dsl_user_op
def f32_to_i32_rz(x: Float32, *, loc=None, ip=None) -> Int32:
"""Convert Float32 to Int32 with round-toward-zero (RZ).
Wraps PTX: cvt.rzi.s32.f32 $0, $1;
Equivalent to CUDA __float2int_rz(). Used for truncation.
"""
result = nvvm.inline_ptx(
write_only_args=[T.i32()],
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
ptx_code="cvt.rzi.s32.f32 $0, $1;",
loc=loc,
ip=ip,
)
return Int32(result)
@dsl_user_op
def f32_to_i32_rmi(x: Float32, *, loc=None, ip=None) -> Int32:
"""Convert Float32 to Int32 with round-to-minus-infinity (RMI / floor).
Wraps PTX: cvt.rmi.s32.f32 $0, $1;
Equivalent to CUDA __float2int_rd() / floorf() → int.
Used for floor(log2()) extraction in FP8 encoding.
"""
result = nvvm.inline_ptx(
write_only_args=[T.i32()],
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
ptx_code="cvt.rmi.s32.f32 $0, $1;",
loc=loc,
ip=ip,
)
return Int32(result)
idx = cutlass.Int32(0)
if a > cutlass.Float32(0.25): idx = cutlass.Int32(1)
if a >= cutlass.Float32(0.75): idx = cutlass.Int32(2)
if a > cutlass.Float32(1.25): idx = cutlass.Int32(3)
if a >= cutlass.Float32(1.75): idx = cutlass.Int32(4)
# hs=5 → idx=4 (5 is odd, so 2.5 ties round to 2 hs → idx 4)
if a >= cutlass.Float32(2.75): idx = cutlass.Int32(5)
if a >= cutlass.Float32(3.75): idx = cutlass.Int32(6)
# hs=8,9 → idx=6
if a > cutlass.Float32(5.25): idx = cutlass.Int32(7)
return idx
# ── FP8 E4M3 encoding ───────────────────────────────────────────────
@cute.jit
def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
"""Convert a positive Float32 value to FP8 E4M3 bit pattern (returned as Int32).
Only handles positive values (NVFP4 scale factors are always positive).
Returns the uint8 bit pattern packed into an Int32.
Uses proper PTX inline asm for float→int conversions:
- cvt.rmi (floor) for exponent extraction
- cvt.rni (round-to-nearest-even) for mantissa quantization
"""
result = cutlass.Int32(0) # default: zero
"""Convert a positive Float32 value to FP8 E4M3 bit pattern (as Int32)."""
result = cutlass.Int32(0)
if val > cutlass.Float32(0.0):
# Clamp to FP8 E4M3 max non-NaN value (exp=15, mant=6 = 448.0)
clamped = cute.arch.fmin(val, cutlass.Float32(448.0))
# Compute floor(log2(clamped)) using frexp-like normalization.
# Double until >= 1, halve until < 2, tracking exponent shift.
# Normalize to [1, 2), tracking floor(log2(clamped))
norm = clamped
exp_floor = cutlass.Int32(0)
# Double until >= 1 (at most 7 doublings needed, smallest normal ≈ 2^-6)
for _ in cutlass.range(7, unroll=1):
if norm < cutlass.Float32(1.0):
norm = norm * cutlass.Float32(2.0)
exp_floor = exp_floor - cutlass.Int32(1)
# Halve until < 2 (at most 8 halvings needed, largest ≈ 448 < 512)
for _ in cutlass.range(8, unroll=1):
if norm >= cutlass.Float32(2.0):
norm = norm * cutlass.Float32(0.5)
exp_floor = exp_floor + cutlass.Int32(1)
# FP8 exponent = floor(log2(val)) + bias
fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS)
if fp8_exp > cutlass.Int32(15):
fp8_exp = cutlass.Int32(15)
if fp8_exp < cutlass.Int32(0):
fp8_exp = cutlass.Int32(0)
# Mantissa for normal: (norm - 1) * 8, round via PTX cvt.rni
if fp8_exp > cutlass.Int32(15): fp8_exp = cutlass.Int32(15)
if fp8_exp < cutlass.Int32(0): fp8_exp = cutlass.Int32(0)
mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0)
mantissa = f32_to_i32_rni(mantissa_f)
# Mantissa overflow: rounded to 8 → increment exponent, reset mantissa
mantissa = round_rne_u0_8(mantissa_f)
if mantissa >= cutlass.Int32(8):
mantissa = cutlass.Int32(0)
fp8_exp = fp8_exp + cutlass.Int32(1)
# Clamp mantissa to [0, 7]
if mantissa < cutlass.Int32(0):
mantissa = cutlass.Int32(0)
if mantissa > cutlass.Int32(7):
mantissa = cutlass.Int32(7)
# Clamp exponent to [0, 15]
if fp8_exp < cutlass.Int32(0):
fp8_exp = cutlass.Int32(0)
if fp8_exp > cutlass.Int32(15):
fp8_exp = cutlass.Int32(15)
# NaN guard: FP8 E4M3 with exp=15 and mant=7 is NaN.
# Saturate to max non-NaN (exp=15, mant=6 = 448.0).
if mantissa < cutlass.Int32(0): mantissa = cutlass.Int32(0)
if mantissa > cutlass.Int32(7): mantissa = cutlass.Int32(7)
if fp8_exp < cutlass.Int32(0): fp8_exp = cutlass.Int32(0)
if fp8_exp > cutlass.Int32(15): fp8_exp = cutlass.Int32(15)
if fp8_exp == cutlass.Int32(15):
if mantissa == cutlass.Int32(7):
mantissa = cutlass.Int32(6)
# Subnormal handling: if fp8_exp < 1, value is 2^(1-7) * m/8 = m * 2^(-9)
# m = round(clamped * 2^9) = round(clamped * 512)
if fp8_exp < cutlass.Int32(1):
sub_m_f = clamped * cutlass.Float32(512.0)
sub_m = f32_to_i32_rni(sub_m_f)
if sub_m < cutlass.Int32(0):
sub_m = cutlass.Int32(0)
if sub_m > cutlass.Int32(7):
sub_m = cutlass.Int32(7)
sub_m = round_rne_u0_8(sub_m_f)
if sub_m < cutlass.Int32(0): sub_m = cutlass.Int32(0)
if sub_m > cutlass.Int32(7): sub_m = cutlass.Int32(7)
mantissa = sub_m
fp8_exp = cutlass.Int32(0)
result = (fp8_exp << cutlass.Int32(3)) | mantissa
return result
@cute.jit
def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
"""Convert FP8 E4M3 bit pattern (in Int32) back to Float32.
Normal: val = 2^(e-7) * (1 + m/8)
Subnormal (e=0): val = m * 2^(-9) = m / 512
"""
"""Convert FP8 E4M3 bit pattern (in Int32) back to Float32."""
mantissa = bits & cutlass.Int32(7)
exponent = (bits >> cutlass.Int32(3)) & cutlass.Int32(15)
# Compute 2^(e-7) by iterative doubling/halving from 1.0
scale = cutlass.Float32(1.0)
exp_delta = exponent - cutlass.Int32(FP8_E4M3_BIAS)
# Double for positive delta (max delta=8, e=15)
d = exp_delta
for _ in cutlass.range(8, unroll=1):
if d > cutlass.Int32(0):
scale = scale * cutlass.Float32(2.0)
d = d - cutlass.Int32(1)
# Halve for negative delta (min delta=-7, e=0)
d = exp_delta
for _ in cutlass.range(7, unroll=1):
if d < cutlass.Int32(0):
scale = scale * cutlass.Float32(0.5)
d = d + cutlass.Int32(1)
# Normal value
normal_val = (cutlass.Float32(1.0) + cutlass.Float32(mantissa) / cutlass.Float32(8.0)) * scale
# Subnormal value (e=0): val = m / 512
subnormal_val = cutlass.Float32(mantissa) / cutlass.Float32(512.0)
# Select
result = cutlass.Float32(0.0)
if exponent > cutlass.Int32(0):
result = normal_val
if exponent == cutlass.Int32(0):
if mantissa > cutlass.Int32(0):
result = subnormal_val
return result
# ── E2M1 FP4 quantization ───────────────────────────────────────────
# E2M1 format (NVFP4 element format):
# Values: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# Index: [0, 1, 2, 3, 4, 5, 6, 7]
# Encoded as 3-bit index + 1 sign bit = 4-bit nibble
# Half-step lookup: quantize as round(|x| * 2) → half_step, then map to E2M1 index.
# This is equivalent to the CUDA reference approach:
# int half_step = __float2int_rn(fabsf(scaled) * 2.0f);
# int idx = half_step_to_e2m1_idx[half_step];
#
# The half_step_to_e2m1_idx LUT (12 entries, half_step 0..11):
_HALF_STEP_TO_E2M1 = [0, 1, 2, 3, 4, 4, 5, 6, 6, 6, 7, 7]
# E2M1: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] → indices [0..7]
# half_step LUT: [0,1,2,3,4,4,5,6,6,6,7,7]
@cute.jit
def quantize_e2m1_nibble(
@@ -247,55 +176,21 @@ def quantize_e2m1_nibble(
scale: cutlass.Float32,
) -> cutlass.Int32:
"""Quantize a single FP32 value to a 4-bit E2M1 nibble.
Returns uint4 nibble: bit 3 = sign, bits [2:0] = E2M1 index.
If scale ≈ 0, returns 0 (zero nibble).
Uses proper PTX cvt.rni for float→int conversion (round-to-nearest-even),
then LUT-based half_step → E2M1 index mapping. Matches the CUDA reference
exactly.
"""
nibble = cutlass.Int32(0)
if scale > cutlass.Float32(1e-8):
scaled = val / scale
abs_scaled = cute.arch.fmax(scaled, cutlass.Float32(0.0) - scaled)
abs_scaled = cute.arch.fmin(abs_scaled, cutlass.Float32(6.0))
# half_step = round(|scaled| * 2) via PTX cvt.rni
half_step_f = abs_scaled * cutlass.Float32(2.0)
half_step = f32_to_i32_rni(half_step_f)
# Clamp to LUT range [0, 11]
if half_step < cutlass.Int32(0):
half_step = cutlass.Int32(0)
if half_step > cutlass.Int32(11):
half_step = cutlass.Int32(11)
# LUT: half_step → E2M1 index
# [0,1,2,3,4,4,5,6,6,6,7,7]
# Expressed as branch logic (CuTeDSL has no random-access arrays in @cute.jit)
idx = cutlass.Int32(0)
if half_step >= cutlass.Int32(1):
idx = cutlass.Int32(1)
if half_step >= cutlass.Int32(2):
idx = cutlass.Int32(2)
if half_step >= cutlass.Int32(3):
idx = cutlass.Int32(3)
if half_step >= cutlass.Int32(4):
idx = cutlass.Int32(4)
# hs=5 also maps to 4 (NOT 5)
if half_step >= cutlass.Int32(6):
idx = cutlass.Int32(5)
if half_step >= cutlass.Int32(7):
idx = cutlass.Int32(6)
# hs=8,9 also map to 6 (NOT 7,8)
if half_step >= cutlass.Int32(10):
idx = cutlass.Int32(7)
idx = abs_scaled_to_e2m1_idx(abs_scaled)
if scaled < cutlass.Float32(0.0):
nibble = idx + cutlass.Int32(8)
if scaled >= cutlass.Float32(0.0):
nibble = idx
return nibble