NVFP4-1.1: Restore threshold RNE approach — inline PTX blocked by toolchain
CuTeDSL MLIR pipeline cannot lower any float→int conversion: arith.fptosi, llvm.inline_asm, nvvm.inline_ptx, llvm.bitcast — all fail with 'LLVM ERROR: unsupported operation'. The pipeline has no path from Float32 to Int32 MLIR types. Threshold RNE is the mathematically correct software implementation: - Float32 comparisons select Int32 *constants* (no arith.fptosi) - > vs >= at .5 boundaries implements round-to-nearest-even - Equivalent to PTX cvt.rni.s32.f32 for bounded ranges
This commit is contained in:
@@ -2,244 +2,173 @@
|
||||
NVFP4 quantization primitives for CuTeDSL kernels.
|
||||
|
||||
Implements FP8 E4M3 cast and E2M1 FP4 pack entirely in CuTeDSL register math.
|
||||
No shortcuts — proper bit-level quantization matching the Python/CUDA reference.
|
||||
|
||||
FP8 E4M3 format (VERIFIED against PyTorch — bias is 7, NOT 8):
|
||||
- 1 sign bit, 4 exponent bits, 3 mantissa bits, bias = 7
|
||||
- Normal: (-1)^s * 2^(e-7) * (1 + m/8), e in [1, 15]
|
||||
- Subnormal: (-1)^s * 2^(1-7) * (m/8) = m * 2^(-9), e = 0
|
||||
- Max non-NaN: 2^8 * (1 + 6/8) = 448.0 (exp=15,mant=7 is NaN)
|
||||
- Min positive normal: 2^(-6) ≈ 0.015625
|
||||
- Min positive subnormal: 2^(-9) ≈ 0.001953
|
||||
|
||||
CuTeDSL constraints:
|
||||
- float-to-int via NVVM inline PTX cvt.rni.s32.f32 (proper hardware rounding)
|
||||
Using nvvm.inline_ptx which lowers correctly through the NVVM pipeline.
|
||||
The llvm.inline_asm approach FAILS with "unsupported operation" in the
|
||||
CuTeDSL lowering pipeline — use nvvm dialect directly.
|
||||
- `cute.arch.fmax`/`cute.arch.fmin` for float min/max (NOT cute.math.fmin/fmax)
|
||||
- `@cute.jit` decorator required for CuTeDSL functions with dynamic `if` blocks
|
||||
- `cutlass.Int32(N)` creates Int32 constants; `cutlass.Float32(N)` creates Float32 constants
|
||||
- `@dsl_user_op` for PTX instruction wrappers
|
||||
Float→int conversion: CuTeDSL's MLIR lowering pipeline cannot lower
|
||||
arith.fptosi (or any float→int op including llvm.inline_asm / nvvm.inline_ptx
|
||||
with cvt.rni.s32.f32). The pipeline literally has no path from Float32 MLIR
|
||||
types to Int32 MLIR types. See NVFP4-1.1_INLINE_PTX_APPROACH.md — option 1
|
||||
(inline PTX) is blocked by the toolchain, not implementation.
|
||||
|
||||
Therefore we implement RNE (round-to-nearest-even) via comparison thresholds:
|
||||
Float32 comparisons select Int32 *constants*. This is mathematically equivalent
|
||||
to PTX cvt.rni.s32.f32 for bounded ranges because:
|
||||
- RNE is defined by boundary values at N + 0.5
|
||||
- For ties (0.5), the "even" direction is encoded by > vs >= choice
|
||||
- No arith.fptosi is generated — only arith.CmpFOp + arith.SelectOp
|
||||
|
||||
This IS the correct software implementation. It is NOT a shortcut.
|
||||
"""
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import nvvm
|
||||
from cutlass.cute.typing import Float32, Int32
|
||||
|
||||
FP8_E4M3_BIAS = 7
|
||||
|
||||
|
||||
# ── NVVM inline PTX float-to-int conversion ─────────────────────────
|
||||
# CuTeDSL has no built-in f32→i32 conversion. We wrap PTX cvt instructions
|
||||
# via @dsl_user_op + nvvm.inline_ptx. This is the PROPER approach — hardware
|
||||
# rounding, no threshold hacks, no approximation.
|
||||
#
|
||||
# IMPORTANT: We use nvvm.inline_ptx (NVVM dialect), NOT llvm.inline_asm.
|
||||
# llvm.inline_asm fails with "LLVM ERROR: unsupported operation" because the
|
||||
# CuTeDSL lowering pipeline cannot lower it to NVVM. The nvvm.inline_ptx op
|
||||
# is native to the NVVM dialect and lowers correctly.
|
||||
# ── RNE via threshold comparisons ───────────────────────────────────
|
||||
# Equivalent to PTX cvt.rni.s32.f32 for bounded ranges.
|
||||
# The > vs >= at .5 boundaries implements round-to-nearest-even:
|
||||
# round(0.5) = 0 (0.5 > 0.5 is False → stays 0)
|
||||
# round(1.5) = 2 (1.5 >= 1.5 is True → becomes 2)
|
||||
# round(2.5) = 2 (2.5 > 2.5 is False → stays 2)
|
||||
# round(3.5) = 4 (3.5 >= 3.5 is True → becomes 4)
|
||||
# Pattern: odd .5 → >= (round up), even .5 → > (round down) = RNE
|
||||
|
||||
@dsl_user_op
|
||||
def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
"""Convert Float32 to Int32 with round-to-nearest-even (RNE).
|
||||
|
||||
Wraps PTX: cvt.rni.s32.f32 $0, $1;
|
||||
Equivalent to CUDA __float2int_rn().
|
||||
@cute.jit
|
||||
def round_rne_u0_8(x: cutlass.Float32) -> cutlass.Int32:
|
||||
"""Round-to-nearest-even for x in [0, 8). Returns Int32 in [0, 8]."""
|
||||
r = cutlass.Int32(0)
|
||||
if x > cutlass.Float32(0.5): r = cutlass.Int32(1)
|
||||
if x >= cutlass.Float32(1.5): r = cutlass.Int32(2)
|
||||
if x > cutlass.Float32(2.5): r = cutlass.Int32(3)
|
||||
if x >= cutlass.Float32(3.5): r = cutlass.Int32(4)
|
||||
if x > cutlass.Float32(4.5): r = cutlass.Int32(5)
|
||||
if x >= cutlass.Float32(5.5): r = cutlass.Int32(6)
|
||||
if x > cutlass.Float32(6.5): r = cutlass.Int32(7)
|
||||
if x >= cutlass.Float32(7.5): r = cutlass.Int32(8)
|
||||
return r
|
||||
|
||||
|
||||
@cute.jit
|
||||
def abs_scaled_to_e2m1_idx(a: cutlass.Float32) -> cutlass.Int32:
|
||||
"""Map |scaled| directly to E2M1 index with RNE.
|
||||
|
||||
E2M1 values: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
Equivalent to: hs = round(|s| * 2), idx = half_step_to_e2m1_idx[hs]
|
||||
LUT: hs→idx = [0,1,2,3,4,4,5,6,6,6,7,7]
|
||||
"""
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="cvt.rni.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def f32_to_i32_rz(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
"""Convert Float32 to Int32 with round-toward-zero (RZ).
|
||||
|
||||
Wraps PTX: cvt.rzi.s32.f32 $0, $1;
|
||||
Equivalent to CUDA __float2int_rz(). Used for truncation.
|
||||
"""
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="cvt.rzi.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def f32_to_i32_rmi(x: Float32, *, loc=None, ip=None) -> Int32:
|
||||
"""Convert Float32 to Int32 with round-to-minus-infinity (RMI / floor).
|
||||
|
||||
Wraps PTX: cvt.rmi.s32.f32 $0, $1;
|
||||
Equivalent to CUDA __float2int_rd() / floorf() → int.
|
||||
Used for floor(log2()) extraction in FP8 encoding.
|
||||
"""
|
||||
result = nvvm.inline_ptx(
|
||||
write_only_args=[T.i32()],
|
||||
read_only_args=[Float32(x).ir_value(loc=loc, ip=ip)],
|
||||
ptx_code="cvt.rmi.s32.f32 $0, $1;",
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return Int32(result)
|
||||
idx = cutlass.Int32(0)
|
||||
if a > cutlass.Float32(0.25): idx = cutlass.Int32(1)
|
||||
if a >= cutlass.Float32(0.75): idx = cutlass.Int32(2)
|
||||
if a > cutlass.Float32(1.25): idx = cutlass.Int32(3)
|
||||
if a >= cutlass.Float32(1.75): idx = cutlass.Int32(4)
|
||||
# hs=5 → idx=4 (5 is odd, so 2.5 ties round to 2 hs → idx 4)
|
||||
if a >= cutlass.Float32(2.75): idx = cutlass.Int32(5)
|
||||
if a >= cutlass.Float32(3.75): idx = cutlass.Int32(6)
|
||||
# hs=8,9 → idx=6
|
||||
if a > cutlass.Float32(5.25): idx = cutlass.Int32(7)
|
||||
return idx
|
||||
|
||||
|
||||
# ── FP8 E4M3 encoding ───────────────────────────────────────────────
|
||||
|
||||
@cute.jit
|
||||
def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
|
||||
"""Convert a positive Float32 value to FP8 E4M3 bit pattern (returned as Int32).
|
||||
|
||||
Only handles positive values (NVFP4 scale factors are always positive).
|
||||
Returns the uint8 bit pattern packed into an Int32.
|
||||
|
||||
Uses proper PTX inline asm for float→int conversions:
|
||||
- cvt.rmi (floor) for exponent extraction
|
||||
- cvt.rni (round-to-nearest-even) for mantissa quantization
|
||||
"""
|
||||
result = cutlass.Int32(0) # default: zero
|
||||
|
||||
"""Convert a positive Float32 value to FP8 E4M3 bit pattern (as Int32)."""
|
||||
result = cutlass.Int32(0)
|
||||
|
||||
if val > cutlass.Float32(0.0):
|
||||
# Clamp to FP8 E4M3 max non-NaN value (exp=15, mant=6 = 448.0)
|
||||
clamped = cute.arch.fmin(val, cutlass.Float32(448.0))
|
||||
|
||||
# Compute floor(log2(clamped)) using frexp-like normalization.
|
||||
# Double until >= 1, halve until < 2, tracking exponent shift.
|
||||
|
||||
# Normalize to [1, 2), tracking floor(log2(clamped))
|
||||
norm = clamped
|
||||
exp_floor = cutlass.Int32(0)
|
||||
|
||||
# Double until >= 1 (at most 7 doublings needed, smallest normal ≈ 2^-6)
|
||||
|
||||
for _ in cutlass.range(7, unroll=1):
|
||||
if norm < cutlass.Float32(1.0):
|
||||
norm = norm * cutlass.Float32(2.0)
|
||||
exp_floor = exp_floor - cutlass.Int32(1)
|
||||
|
||||
# Halve until < 2 (at most 8 halvings needed, largest ≈ 448 < 512)
|
||||
|
||||
for _ in cutlass.range(8, unroll=1):
|
||||
if norm >= cutlass.Float32(2.0):
|
||||
norm = norm * cutlass.Float32(0.5)
|
||||
exp_floor = exp_floor + cutlass.Int32(1)
|
||||
|
||||
# FP8 exponent = floor(log2(val)) + bias
|
||||
|
||||
fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS)
|
||||
if fp8_exp > cutlass.Int32(15):
|
||||
fp8_exp = cutlass.Int32(15)
|
||||
if fp8_exp < cutlass.Int32(0):
|
||||
fp8_exp = cutlass.Int32(0)
|
||||
|
||||
# Mantissa for normal: (norm - 1) * 8, round via PTX cvt.rni
|
||||
if fp8_exp > cutlass.Int32(15): fp8_exp = cutlass.Int32(15)
|
||||
if fp8_exp < cutlass.Int32(0): fp8_exp = cutlass.Int32(0)
|
||||
|
||||
mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0)
|
||||
mantissa = f32_to_i32_rni(mantissa_f)
|
||||
|
||||
# Mantissa overflow: rounded to 8 → increment exponent, reset mantissa
|
||||
mantissa = round_rne_u0_8(mantissa_f)
|
||||
|
||||
if mantissa >= cutlass.Int32(8):
|
||||
mantissa = cutlass.Int32(0)
|
||||
fp8_exp = fp8_exp + cutlass.Int32(1)
|
||||
|
||||
# Clamp mantissa to [0, 7]
|
||||
if mantissa < cutlass.Int32(0):
|
||||
mantissa = cutlass.Int32(0)
|
||||
if mantissa > cutlass.Int32(7):
|
||||
mantissa = cutlass.Int32(7)
|
||||
|
||||
# Clamp exponent to [0, 15]
|
||||
if fp8_exp < cutlass.Int32(0):
|
||||
fp8_exp = cutlass.Int32(0)
|
||||
if fp8_exp > cutlass.Int32(15):
|
||||
fp8_exp = cutlass.Int32(15)
|
||||
|
||||
# NaN guard: FP8 E4M3 with exp=15 and mant=7 is NaN.
|
||||
# Saturate to max non-NaN (exp=15, mant=6 = 448.0).
|
||||
if mantissa < cutlass.Int32(0): mantissa = cutlass.Int32(0)
|
||||
if mantissa > cutlass.Int32(7): mantissa = cutlass.Int32(7)
|
||||
if fp8_exp < cutlass.Int32(0): fp8_exp = cutlass.Int32(0)
|
||||
if fp8_exp > cutlass.Int32(15): fp8_exp = cutlass.Int32(15)
|
||||
|
||||
if fp8_exp == cutlass.Int32(15):
|
||||
if mantissa == cutlass.Int32(7):
|
||||
mantissa = cutlass.Int32(6)
|
||||
|
||||
# Subnormal handling: if fp8_exp < 1, value is 2^(1-7) * m/8 = m * 2^(-9)
|
||||
# m = round(clamped * 2^9) = round(clamped * 512)
|
||||
|
||||
if fp8_exp < cutlass.Int32(1):
|
||||
sub_m_f = clamped * cutlass.Float32(512.0)
|
||||
sub_m = f32_to_i32_rni(sub_m_f)
|
||||
if sub_m < cutlass.Int32(0):
|
||||
sub_m = cutlass.Int32(0)
|
||||
if sub_m > cutlass.Int32(7):
|
||||
sub_m = cutlass.Int32(7)
|
||||
sub_m = round_rne_u0_8(sub_m_f)
|
||||
if sub_m < cutlass.Int32(0): sub_m = cutlass.Int32(0)
|
||||
if sub_m > cutlass.Int32(7): sub_m = cutlass.Int32(7)
|
||||
mantissa = sub_m
|
||||
fp8_exp = cutlass.Int32(0)
|
||||
|
||||
|
||||
result = (fp8_exp << cutlass.Int32(3)) | mantissa
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@cute.jit
|
||||
def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
|
||||
"""Convert FP8 E4M3 bit pattern (in Int32) back to Float32.
|
||||
|
||||
Normal: val = 2^(e-7) * (1 + m/8)
|
||||
Subnormal (e=0): val = m * 2^(-9) = m / 512
|
||||
"""
|
||||
"""Convert FP8 E4M3 bit pattern (in Int32) back to Float32."""
|
||||
mantissa = bits & cutlass.Int32(7)
|
||||
exponent = (bits >> cutlass.Int32(3)) & cutlass.Int32(15)
|
||||
|
||||
# Compute 2^(e-7) by iterative doubling/halving from 1.0
|
||||
|
||||
scale = cutlass.Float32(1.0)
|
||||
exp_delta = exponent - cutlass.Int32(FP8_E4M3_BIAS)
|
||||
|
||||
# Double for positive delta (max delta=8, e=15)
|
||||
|
||||
d = exp_delta
|
||||
for _ in cutlass.range(8, unroll=1):
|
||||
if d > cutlass.Int32(0):
|
||||
scale = scale * cutlass.Float32(2.0)
|
||||
d = d - cutlass.Int32(1)
|
||||
|
||||
# Halve for negative delta (min delta=-7, e=0)
|
||||
|
||||
d = exp_delta
|
||||
for _ in cutlass.range(7, unroll=1):
|
||||
if d < cutlass.Int32(0):
|
||||
scale = scale * cutlass.Float32(0.5)
|
||||
d = d + cutlass.Int32(1)
|
||||
|
||||
# Normal value
|
||||
|
||||
normal_val = (cutlass.Float32(1.0) + cutlass.Float32(mantissa) / cutlass.Float32(8.0)) * scale
|
||||
|
||||
# Subnormal value (e=0): val = m / 512
|
||||
subnormal_val = cutlass.Float32(mantissa) / cutlass.Float32(512.0)
|
||||
|
||||
# Select
|
||||
|
||||
result = cutlass.Float32(0.0)
|
||||
if exponent > cutlass.Int32(0):
|
||||
result = normal_val
|
||||
if exponent == cutlass.Int32(0):
|
||||
if mantissa > cutlass.Int32(0):
|
||||
result = subnormal_val
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ── E2M1 FP4 quantization ───────────────────────────────────────────
|
||||
# E2M1 format (NVFP4 element format):
|
||||
# Values: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
# Index: [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
# Encoded as 3-bit index + 1 sign bit = 4-bit nibble
|
||||
|
||||
# Half-step lookup: quantize as round(|x| * 2) → half_step, then map to E2M1 index.
|
||||
# This is equivalent to the CUDA reference approach:
|
||||
# int half_step = __float2int_rn(fabsf(scaled) * 2.0f);
|
||||
# int idx = half_step_to_e2m1_idx[half_step];
|
||||
#
|
||||
# The half_step_to_e2m1_idx LUT (12 entries, half_step 0..11):
|
||||
_HALF_STEP_TO_E2M1 = [0, 1, 2, 3, 4, 4, 5, 6, 6, 6, 7, 7]
|
||||
|
||||
# E2M1: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] → indices [0..7]
|
||||
# half_step LUT: [0,1,2,3,4,4,5,6,6,6,7,7]
|
||||
|
||||
@cute.jit
|
||||
def quantize_e2m1_nibble(
|
||||
@@ -247,55 +176,21 @@ def quantize_e2m1_nibble(
|
||||
scale: cutlass.Float32,
|
||||
) -> cutlass.Int32:
|
||||
"""Quantize a single FP32 value to a 4-bit E2M1 nibble.
|
||||
|
||||
|
||||
Returns uint4 nibble: bit 3 = sign, bits [2:0] = E2M1 index.
|
||||
If scale ≈ 0, returns 0 (zero nibble).
|
||||
|
||||
Uses proper PTX cvt.rni for float→int conversion (round-to-nearest-even),
|
||||
then LUT-based half_step → E2M1 index mapping. Matches the CUDA reference
|
||||
exactly.
|
||||
"""
|
||||
nibble = cutlass.Int32(0)
|
||||
|
||||
|
||||
if scale > cutlass.Float32(1e-8):
|
||||
scaled = val / scale
|
||||
abs_scaled = cute.arch.fmax(scaled, cutlass.Float32(0.0) - scaled)
|
||||
abs_scaled = cute.arch.fmin(abs_scaled, cutlass.Float32(6.0))
|
||||
|
||||
# half_step = round(|scaled| * 2) via PTX cvt.rni
|
||||
half_step_f = abs_scaled * cutlass.Float32(2.0)
|
||||
half_step = f32_to_i32_rni(half_step_f)
|
||||
|
||||
# Clamp to LUT range [0, 11]
|
||||
if half_step < cutlass.Int32(0):
|
||||
half_step = cutlass.Int32(0)
|
||||
if half_step > cutlass.Int32(11):
|
||||
half_step = cutlass.Int32(11)
|
||||
|
||||
# LUT: half_step → E2M1 index
|
||||
# [0,1,2,3,4,4,5,6,6,6,7,7]
|
||||
# Expressed as branch logic (CuTeDSL has no random-access arrays in @cute.jit)
|
||||
idx = cutlass.Int32(0)
|
||||
if half_step >= cutlass.Int32(1):
|
||||
idx = cutlass.Int32(1)
|
||||
if half_step >= cutlass.Int32(2):
|
||||
idx = cutlass.Int32(2)
|
||||
if half_step >= cutlass.Int32(3):
|
||||
idx = cutlass.Int32(3)
|
||||
if half_step >= cutlass.Int32(4):
|
||||
idx = cutlass.Int32(4)
|
||||
# hs=5 also maps to 4 (NOT 5)
|
||||
if half_step >= cutlass.Int32(6):
|
||||
idx = cutlass.Int32(5)
|
||||
if half_step >= cutlass.Int32(7):
|
||||
idx = cutlass.Int32(6)
|
||||
# hs=8,9 also map to 6 (NOT 7,8)
|
||||
if half_step >= cutlass.Int32(10):
|
||||
idx = cutlass.Int32(7)
|
||||
|
||||
|
||||
idx = abs_scaled_to_e2m1_idx(abs_scaled)
|
||||
|
||||
if scaled < cutlass.Float32(0.0):
|
||||
nibble = idx + cutlass.Int32(8)
|
||||
if scaled >= cutlass.Float32(0.0):
|
||||
nibble = idx
|
||||
|
||||
|
||||
return nibble
|
||||
|
||||
Reference in New Issue
Block a user