llvm.inline_asm fails with 'LLVM ERROR: unsupported operation' in CuTeDSL lowering pipeline. Switch to nvvm.inline_ptx which is native to the NVVM dialect and lowers correctly. - f32_to_i32_rni: cvt.rni.s32.f32 via nvvm.inline_ptx - f32_to_i32_rz: cvt.rzi.s32.f32 via nvvm.inline_ptx - f32_to_i32_rmi: cvt.rmi.s32.f32 via nvvm.inline_ptx
302 lines
11 KiB
Python
302 lines
11 KiB
Python
"""
|
|
NVFP4 quantization primitives for CuTeDSL kernels.
|
|
|
|
Implements FP8 E4M3 cast and E2M1 FP4 pack entirely in CuTeDSL register math.
|
|
No shortcuts — proper bit-level quantization matching the Python/CUDA reference.
|
|
|
|
FP8 E4M3 format (VERIFIED against PyTorch — bias is 7, NOT 8):
|
|
- 1 sign bit, 4 exponent bits, 3 mantissa bits, bias = 7
|
|
- Normal: (-1)^s * 2^(e-7) * (1 + m/8), e in [1, 15]
|
|
- Subnormal: (-1)^s * 2^(1-7) * (m/8) = m * 2^(-9), e = 0
|
|
- Max non-NaN: 2^8 * (1 + 6/8) = 448.0 (exp=15,mant=7 is NaN)
|
|
- Min positive normal: 2^(-6) = 0.015625 (exact)
|
|
- Min positive subnormal: 2^(-9) ≈ 0.001953
|
|
|
|
CuTeDSL constraints:
|
|
- float-to-int via NVVM inline PTX cvt.rni.s32.f32 (proper hardware rounding)
|
|
Using nvvm.inline_ptx which lowers correctly through the NVVM pipeline.
|
|
The llvm.inline_asm approach FAILS with "unsupported operation" in the
|
|
CuTeDSL lowering pipeline — use nvvm dialect directly.
|
|
- `cute.arch.fmax`/`cute.arch.fmin` for float min/max (NOT cute.math.fmin/fmax)
|
|
- `@cute.jit` decorator required for CuTeDSL functions with dynamic `if` blocks
|
|
- `cutlass.Int32(N)` creates Int32 constants; `cutlass.Float32(N)` creates Float32 constants
|
|
- `@dsl_user_op` for PTX instruction wrappers
|
|
"""
|
|
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
from cutlass.cutlass_dsl import dsl_user_op, T
|
|
from cutlass._mlir.dialects import nvvm
|
|
from cutlass.cute.typing import Float32, Int32
|
|
|
|
# Exponent bias of the FP8 E4M3 format: stored exponent field = unbiased
# exponent + 7 (normal values are 2^(e-7) * (1 + m/8)).
FP8_E4M3_BIAS = 7
|
|
|
|
|
|
# ── NVVM inline PTX float-to-int conversion ─────────────────────────
|
|
# CuTeDSL has no built-in f32→i32 conversion. We wrap PTX cvt instructions
|
|
# via @dsl_user_op + nvvm.inline_ptx. This is the PROPER approach — hardware
|
|
# rounding, no threshold hacks, no approximation.
|
|
#
|
|
# IMPORTANT: We use nvvm.inline_ptx (NVVM dialect), NOT llvm.inline_asm.
|
|
# llvm.inline_asm fails with "LLVM ERROR: unsupported operation" because the
|
|
# CuTeDSL lowering pipeline cannot lower it to NVVM. The nvvm.inline_ptx op
|
|
# is native to the NVVM dialect and lowers correctly.
|
|
|
|
@dsl_user_op
def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
    """Round a Float32 to the nearest Int32, ties to even.

    Emits the PTX instruction ``cvt.rni.s32.f32 $0, $1;`` through
    ``nvvm.inline_ptx`` (the CUDA counterpart is ``__float2int_rn()``).

    Args:
        x: Value to convert; it is coerced with ``Float32(x)`` first.
        loc: Optional MLIR location forwarded to the emitted ops.
        ip: Optional MLIR insertion point forwarded to the emitted ops.

    Returns:
        The converted value wrapped as ``Int32``.
    """
    i32_type = T.i32()
    operand = Float32(x).ir_value(loc=loc, ip=ip)
    converted = nvvm.inline_ptx(
        write_only_args=[i32_type],
        read_only_args=[operand],
        ptx_code="cvt.rni.s32.f32 $0, $1;",
        loc=loc,
        ip=ip,
    )
    return Int32(converted)
|
|
|
|
|
|
@dsl_user_op
def f32_to_i32_rz(x: Float32, *, loc=None, ip=None) -> Int32:
    """Truncate a Float32 toward zero and return it as Int32.

    Emits the PTX instruction ``cvt.rzi.s32.f32 $0, $1;`` through
    ``nvvm.inline_ptx`` (the CUDA counterpart is ``__float2int_rz()``).

    Args:
        x: Value to convert; it is coerced with ``Float32(x)`` first.
        loc: Optional MLIR location forwarded to the emitted ops.
        ip: Optional MLIR insertion point forwarded to the emitted ops.

    Returns:
        The converted value wrapped as ``Int32``.
    """
    i32_type = T.i32()
    operand = Float32(x).ir_value(loc=loc, ip=ip)
    converted = nvvm.inline_ptx(
        write_only_args=[i32_type],
        read_only_args=[operand],
        ptx_code="cvt.rzi.s32.f32 $0, $1;",
        loc=loc,
        ip=ip,
    )
    return Int32(converted)
|
|
|
|
|
|
@dsl_user_op
def f32_to_i32_rmi(x: Float32, *, loc=None, ip=None) -> Int32:
    """Round a Float32 toward minus infinity (floor) and return it as Int32.

    Emits the PTX instruction ``cvt.rmi.s32.f32 $0, $1;`` through
    ``nvvm.inline_ptx`` (the CUDA counterpart is ``__float2int_rd()``).

    NOTE(review): intended for floor(log2(.)) style exponent extraction, but
    no caller is visible in this part of the file - confirm before removing.

    Args:
        x: Value to convert; it is coerced with ``Float32(x)`` first.
        loc: Optional MLIR location forwarded to the emitted ops.
        ip: Optional MLIR insertion point forwarded to the emitted ops.

    Returns:
        The converted value wrapped as ``Int32``.
    """
    i32_type = T.i32()
    operand = Float32(x).ir_value(loc=loc, ip=ip)
    converted = nvvm.inline_ptx(
        write_only_args=[i32_type],
        read_only_args=[operand],
        ptx_code="cvt.rmi.s32.f32 $0, $1;",
        loc=loc,
        ip=ip,
    )
    return Int32(converted)
|
|
|
|
|
|
# ── FP8 E4M3 encoding ───────────────────────────────────────────────
|
|
|
|
@cute.jit
def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
    """Encode a positive Float32 as an FP8 E4M3 bit pattern (returned as Int32).

    Only inputs that satisfy ``val > 0`` are encoded; everything else returns
    0. Inputs above 448.0 saturate to 448.0 (exp=15, mant=6), the largest
    non-NaN E4M3 value. The sign bit is never set (NVFP4 scale factors are
    always positive).

    How it works:
      - The exponent is obtained by iterative doubling/halving normalisation
        of the value into [1, 2); no float->int conversion is used for it.
      - The normal mantissa ``(norm - 1) * 8`` and the subnormal mantissa
        ``val * 512`` are rounded with PTX cvt.rni (round-to-nearest-even)
        via ``f32_to_i32_rni``.
      - A mantissa that rounds up to 8 carries into the exponent, in both the
        normal and the subnormal path.

    Args:
        val: Value to encode.

    Returns:
        Int32 whose low bits hold ``(exponent << 3) | mantissa``.
    """
    result = cutlass.Int32(0)  # default: zero

    if val > cutlass.Float32(0.0):
        # Saturate to the FP8 E4M3 max non-NaN value (exp=15, mant=6 = 448.0).
        clamped = cute.arch.fmin(val, cutlass.Float32(448.0))

        # Normalise into [1, 2) while tracking floor(log2(clamped)).
        norm = clamped
        exp_floor = cutlass.Int32(0)

        # Up to 7 doublings: reaches [1, 2) for anything >= 2^-7. Smaller
        # inputs stay below 1 with exp_floor = -7 and are handled by the
        # subnormal path below.
        for _ in cutlass.range(7, unroll=1):
            if norm < cutlass.Float32(1.0):
                norm = norm * cutlass.Float32(2.0)
                exp_floor = exp_floor - cutlass.Int32(1)

        # Up to 8 halvings: enough because clamped <= 448 < 512 = 2^9.
        for _ in cutlass.range(8, unroll=1):
            if norm >= cutlass.Float32(2.0):
                norm = norm * cutlass.Float32(0.5)
                exp_floor = exp_floor + cutlass.Int32(1)

        # Biased exponent. exp_floor is in [-7, 8] here, so this is already
        # within [0, 15]; the defensive clamp happens once, after the carry.
        fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS)

        # Mantissa for normal values: (norm - 1) * 8, rounded via PTX cvt.rni.
        mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0)
        mantissa = f32_to_i32_rni(mantissa_f)

        # Mantissa overflow: rounded to 8 -> bump exponent, reset mantissa.
        if mantissa >= cutlass.Int32(8):
            mantissa = cutlass.Int32(0)
            fp8_exp = fp8_exp + cutlass.Int32(1)

        # Clamp mantissa to [0, 7] (it is negative when norm stayed below 1).
        if mantissa < cutlass.Int32(0):
            mantissa = cutlass.Int32(0)
        if mantissa > cutlass.Int32(7):
            mantissa = cutlass.Int32(7)

        # Clamp exponent to [0, 15].
        if fp8_exp < cutlass.Int32(0):
            fp8_exp = cutlass.Int32(0)
        if fp8_exp > cutlass.Int32(15):
            fp8_exp = cutlass.Int32(15)

        # NaN guard: exp=15, mant=7 is the NaN encoding.
        # Saturate to max non-NaN (exp=15, mant=6 = 448.0).
        if fp8_exp == cutlass.Int32(15):
            if mantissa == cutlass.Int32(7):
                mantissa = cutlass.Int32(6)

        # Subnormal handling: for fp8_exp < 1 the value is m * 2^(-9), so
        # m = round(clamped * 512).
        if fp8_exp < cutlass.Int32(1):
            sub_m = f32_to_i32_rni(clamped * cutlass.Float32(512.0))
            if sub_m < cutlass.Int32(0):
                sub_m = cutlass.Int32(0)
            mantissa = sub_m
            fp8_exp = cutlass.Int32(0)
            # Rounding can reach 8 (clamped * 512 >= 7.5). That is the
            # smallest normal value 2^-6 (exp=1, mant=0), not subnormal 7:
            # carry into the exponent instead of clamping the mantissa.
            if sub_m > cutlass.Int32(7):
                mantissa = cutlass.Int32(0)
                fp8_exp = cutlass.Int32(1)

        result = (fp8_exp << cutlass.Int32(3)) | mantissa

    return result
|
|
|
|
|
|
@cute.jit
def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
    """Decode an FP8 E4M3 magnitude (held in an Int32) into a Float32.

    Bits [2:0] are the mantissa m and bits [6:3] the exponent e; any higher
    bits, including the sign position, are masked away and ignored.

      - e > 0 (normal):     2^(e-7) * (1 + m/8)
      - e == 0 (subnormal): m / 512, which is 0.0 when m == 0

    No special case exists for e=15, m=7: that pattern goes through the
    normal formula like any other.
    """
    exp_field = (bits >> cutlass.Int32(3)) & cutlass.Int32(15)
    man_field = bits & cutlass.Int32(7)
    man_f = cutlass.Float32(man_field)

    # Build 2^(e-7) by repeated multiplication starting from 1.0.
    shift = exp_field - cutlass.Int32(FP8_E4M3_BIAS)
    pow2 = cutlass.Float32(1.0)

    # Positive shift: at most 8 doublings (e = 15).
    up_left = shift
    for _ in cutlass.range(8, unroll=1):
        if up_left > cutlass.Int32(0):
            pow2 = pow2 * cutlass.Float32(2.0)
            up_left = up_left - cutlass.Int32(1)

    # Negative shift: at most 7 halvings (e = 0).
    down_left = shift
    for _ in cutlass.range(7, unroll=1):
        if down_left < cutlass.Int32(0):
            pow2 = pow2 * cutlass.Float32(0.5)
            down_left = down_left + cutlass.Int32(1)

    # Start from the subnormal reading (covers e == 0, including m == 0 -> 0.0)
    # and overwrite it with the normal reading whenever e is non-zero.
    decoded = man_f / cutlass.Float32(512.0)
    if exp_field > cutlass.Int32(0):
        decoded = (cutlass.Float32(1.0) + man_f / cutlass.Float32(8.0)) * pow2

    return decoded
|
|
|
|
|
|
# ── E2M1 FP4 quantization ───────────────────────────────────────────
|
|
# E2M1 format (NVFP4 element format):
|
|
# Values: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
|
# Index: [0, 1, 2, 3, 4, 5, 6, 7]
|
|
# Encoded as 3-bit index + 1 sign bit = 4-bit nibble
|
|
|
|
# Half-step lookup: quantize as round(|x| * 2) → half_step, then map to E2M1 index.
|
|
# This is equivalent to the CUDA reference approach:
|
|
# int half_step = __float2int_rn(fabsf(scaled) * 2.0f);
|
|
# int idx = half_step_to_e2m1_idx[half_step];
|
|
#
|
|
# The half_step_to_e2m1_idx LUT (12 entries, half_step 0..11):
|
|
# Position = half_step (0..11), value = E2M1 index (0..7).
# NOTE(review): kept as a reference table; quantize_e2m1_nibble reproduces
# this mapping with comparisons instead of indexing this list.
_HALF_STEP_TO_E2M1 = [0, 1, 2, 3, 4, 4, 5, 6, 6, 6, 7, 7]
|
|
|
|
|
|
@cute.jit
def quantize_e2m1_nibble(
    val: cutlass.Float32,
    scale: cutlass.Float32,
) -> cutlass.Int32:
    """Quantize one FP32 value to a 4-bit E2M1 nibble.

    The value is divided by ``scale``, its magnitude is capped at 6.0, and
    ``round(|val / scale| * 2)`` (PTX cvt.rni, round-to-nearest-even) is
    mapped to an E2M1 index with the half-step table
    ``[0, 1, 2, 3, 4, 4, 5, 6, 6, 6, 7, 7]``.

    Args:
        val: Value to quantize.
        scale: Block scale; when it is not greater than 1e-8 the result is 0.

    Returns:
        Int32 nibble: bit 3 is the sign, bits [2:0] the E2M1 index.
    """
    out = cutlass.Int32(0)

    if scale > cutlass.Float32(1e-8):
        ratio = val / scale

        # |ratio| as max(ratio, 0 - ratio), then saturate at the E2M1 maximum.
        magnitude = cute.arch.fmax(ratio, cutlass.Float32(0.0) - ratio)
        magnitude = cute.arch.fmin(magnitude, cutlass.Float32(6.0))

        # Number of 0.5-wide steps, rounded to nearest-even.
        steps = f32_to_i32_rni(magnitude * cutlass.Float32(2.0))

        # Keep the step count inside the table range [0, 11].
        if steps > cutlass.Int32(11):
            steps = cutlass.Int32(11)
        if steps < cutlass.Int32(0):
            steps = cutlass.Int32(0)

        # Table lookup written as ascending thresholds; the last threshold
        # that holds wins. Steps 5, 8 and 9 have no threshold of their own,
        # so they keep the index of the step below them.
        code = cutlass.Int32(0)
        if steps >= cutlass.Int32(1):
            code = cutlass.Int32(1)
        if steps >= cutlass.Int32(2):
            code = cutlass.Int32(2)
        if steps >= cutlass.Int32(3):
            code = cutlass.Int32(3)
        if steps >= cutlass.Int32(4):
            code = cutlass.Int32(4)
        if steps >= cutlass.Int32(6):
            code = cutlass.Int32(5)
        if steps >= cutlass.Int32(7):
            code = cutlass.Int32(6)
        if steps >= cutlass.Int32(10):
            code = cutlass.Int32(7)

        # Two independent tests rather than if/else: when neither comparison
        # holds, the nibble stays at its initial 0.
        if ratio >= cutlass.Float32(0.0):
            out = code
        if ratio < cutlass.Float32(0.0):
            out = code + cutlass.Int32(8)

    return out
|