Files
nvfp4-megamoe-kernel/dsv4/kernels/gemm/fp4_quant.py
biondizzle e33c48e44c NVFP4-1.1: Use nvvm.inline_ptx instead of llvm.inline_asm for f32→i32
llvm.inline_asm fails with 'LLVM ERROR: unsupported operation' in CuTeDSL
lowering pipeline. Switch to nvvm.inline_ptx which is native to the NVVM
dialect and lowers correctly.

- f32_to_i32_rni: cvt.rni.s32.f32 via nvvm.inline_ptx
- f32_to_i32_rz: cvt.rzi.s32.f32 via nvvm.inline_ptx
- f32_to_i32_rmi: cvt.rmi.s32.f32 via nvvm.inline_ptx
2026-05-28 04:42:33 +00:00

302 lines
11 KiB
Python

"""
NVFP4 quantization primitives for CuTeDSL kernels.
Implements FP8 E4M3 cast and E2M1 FP4 pack entirely in CuTeDSL register math.
No shortcuts — proper bit-level quantization matching the Python/CUDA reference.
FP8 E4M3 format (VERIFIED against PyTorch — bias is 7, NOT 8):
- 1 sign bit, 4 exponent bits, 3 mantissa bits, bias = 7
- Normal: (-1)^s * 2^(e-7) * (1 + m/8), e in [1, 15]
- Subnormal: (-1)^s * 2^(1-7) * (m/8) = m * 2^(-9), e = 0
- Max non-NaN: 2^8 * (1 + 6/8) = 448.0 (exp=15,mant=7 is NaN)
- Min positive normal: 2^(-6) = 0.015625
- Min positive subnormal: 2^(-9) = 0.001953125
CuTeDSL constraints:
- float-to-int via NVVM inline PTX cvt.{rni,rzi,rmi}.s32.f32 (hardware rounding),
  emitted with nvvm.inline_ptx, which lowers correctly through the NVVM pipeline.
  The llvm.inline_asm approach FAILS with "unsupported operation" in the
  CuTeDSL lowering pipeline — use the nvvm dialect directly.
- `cute.arch.fmax`/`cute.arch.fmin` for float min/max (NOT cute.math.fmin/fmax)
- `@cute.jit` decorator required for CuTeDSL functions with dynamic `if` blocks
- `cutlass.Int32(N)` creates Int32 constants; `cutlass.Float32(N)` creates Float32 constants
- `@dsl_user_op` for PTX instruction wrappers
"""
import cutlass
import cutlass.cute as cute
from cutlass.cutlass_dsl import dsl_user_op, T
from cutlass._mlir.dialects import nvvm
from cutlass.cute.typing import Float32, Int32
# Exponent bias of FP8 E4M3: a normal value is 2^(e - 7) * (1 + m/8).
FP8_E4M3_BIAS = 7
# ── NVVM inline PTX float-to-int conversion ─────────────────────────
# CuTeDSL has no built-in f32→i32 conversion. We wrap PTX cvt instructions
# via @dsl_user_op + nvvm.inline_ptx. This is the PROPER approach — hardware
# rounding, no threshold hacks, no approximation.
#
# IMPORTANT: We use nvvm.inline_ptx (NVVM dialect), NOT llvm.inline_asm.
# llvm.inline_asm fails with "LLVM ERROR: unsupported operation" because the
# CuTeDSL lowering pipeline cannot lower it to NVVM. The nvvm.inline_ptx op
# is native to the NVVM dialect and lowers correctly.
@dsl_user_op
def f32_to_i32_rni(x: Float32, *, loc=None, ip=None) -> Int32:
    """Round a Float32 to the nearest Int32, ties going to the even integer.

    Emitted as the PTX instruction ``cvt.rni.s32.f32 $0, $1;`` through
    ``nvvm.inline_ptx`` (same semantics as CUDA ``__float2int_rn()``).
    """
    src = Float32(x).ir_value(loc=loc, ip=ip)
    converted = nvvm.inline_ptx(
        write_only_args=[T.i32()],  # $0: the i32 result
        read_only_args=[src],       # $1: the f32 operand
        ptx_code="cvt.rni.s32.f32 $0, $1;",
        loc=loc,
        ip=ip,
    )
    return Int32(converted)
@dsl_user_op
def f32_to_i32_rz(x: Float32, *, loc=None, ip=None) -> Int32:
    """Truncate a Float32 to Int32 (round toward zero).

    Emitted as the PTX instruction ``cvt.rzi.s32.f32 $0, $1;`` through
    ``nvvm.inline_ptx`` (same semantics as CUDA ``__float2int_rz()``).
    """
    src = Float32(x).ir_value(loc=loc, ip=ip)
    converted = nvvm.inline_ptx(
        write_only_args=[T.i32()],  # $0: the i32 result
        read_only_args=[src],       # $1: the f32 operand
        ptx_code="cvt.rzi.s32.f32 $0, $1;",
        loc=loc,
        ip=ip,
    )
    return Int32(converted)
@dsl_user_op
def f32_to_i32_rmi(x: Float32, *, loc=None, ip=None) -> Int32:
    """Floor a Float32 to Int32 (round toward minus infinity).

    Emitted as the PTX instruction ``cvt.rmi.s32.f32 $0, $1;`` through
    ``nvvm.inline_ptx`` (same semantics as CUDA ``__float2int_rd()``, i.e.
    ``floorf()`` followed by an int conversion).
    """
    src = Float32(x).ir_value(loc=loc, ip=ip)
    converted = nvvm.inline_ptx(
        write_only_args=[T.i32()],  # $0: the i32 result
        read_only_args=[src],       # $1: the f32 operand
        ptx_code="cvt.rmi.s32.f32 $0, $1;",
        loc=loc,
        ip=ip,
    )
    return Int32(converted)
# ── FP8 E4M3 encoding ───────────────────────────────────────────────
@cute.jit
def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
    """Convert a positive Float32 value to an FP8 E4M3 bit pattern (returned as Int32).

    Only positive inputs are encoded (NVFP4 scale factors are always positive).
    Any value that does not compare greater than 0.0 yields 0. Inputs above
    448.0 saturate to 448.0 (exp=15, mant=6 -> 0x7E).

    Rounding is round-to-nearest-even on the E4M3 grid. Both conversions go
    through PTX ``cvt.rni`` (f32_to_i32_rni):
      - normal range: mantissa = rni((norm - 1) * 8), carrying into the
        exponent when it rounds up to 8;
      - subnormal range: m = rni(val * 2^9), carrying into the smallest
        normal (exp=1, mant=0 -> 0x08 = 2^-6) when it rounds up to 8.

    Returns:
        The 8-bit E4M3 pattern in the low bits of an Int32 (sign bit always 0).
    """
    result = cutlass.Int32(0)  # default: zero
    if val > cutlass.Float32(0.0):
        # Clamp to FP8 E4M3 max non-NaN value (exp=15, mant=6 = 448.0)
        clamped = cute.arch.fmin(val, cutlass.Float32(448.0))
        # Compute floor(log2(clamped)) using frexp-like normalization:
        # double until >= 1, halve until < 2, tracking the exponent shift.
        # Scaling by powers of two is exact, so norm carries no rounding error.
        norm = clamped
        exp_floor = cutlass.Int32(0)
        # Double until >= 1. Seven doublings reach 2^-7; anything smaller ends
        # up with fp8_exp == 0 and is re-encoded by the subnormal path below.
        for _ in cutlass.range(7, unroll=1):
            if norm < cutlass.Float32(1.0):
                norm = norm * cutlass.Float32(2.0)
                exp_floor = exp_floor - cutlass.Int32(1)
        # Halve until < 2 (at most 8 halvings needed, largest is 448 < 512)
        for _ in cutlass.range(8, unroll=1):
            if norm >= cutlass.Float32(2.0):
                norm = norm * cutlass.Float32(0.5)
                exp_floor = exp_floor + cutlass.Int32(1)
        # FP8 exponent = floor(log2(val)) + bias
        fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS)
        if fp8_exp > cutlass.Int32(15):
            fp8_exp = cutlass.Int32(15)
        if fp8_exp < cutlass.Int32(0):
            fp8_exp = cutlass.Int32(0)
        # Mantissa for normal: (norm - 1) * 8, round via PTX cvt.rni
        mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0)
        mantissa = f32_to_i32_rni(mantissa_f)
        # Mantissa overflow: rounded to 8 -> increment exponent, reset mantissa
        if mantissa >= cutlass.Int32(8):
            mantissa = cutlass.Int32(0)
            fp8_exp = fp8_exp + cutlass.Int32(1)
        # Clamp mantissa to [0, 7]
        if mantissa < cutlass.Int32(0):
            mantissa = cutlass.Int32(0)
        if mantissa > cutlass.Int32(7):
            mantissa = cutlass.Int32(7)
        # Clamp exponent to [0, 15]
        if fp8_exp < cutlass.Int32(0):
            fp8_exp = cutlass.Int32(0)
        if fp8_exp > cutlass.Int32(15):
            fp8_exp = cutlass.Int32(15)
        # NaN guard: FP8 E4M3 with exp=15 and mant=7 is NaN.
        # Saturate to max non-NaN (exp=15, mant=6 = 448.0).
        if fp8_exp == cutlass.Int32(15):
            if mantissa == cutlass.Int32(7):
                mantissa = cutlass.Int32(6)
        # Subnormal handling: if fp8_exp < 1, value is 2^(1-7) * m/8 = m * 2^(-9)
        # m = round(clamped * 2^9) = round(clamped * 512). Here clamped < 2^-6,
        # so m is in [0, 8].
        if fp8_exp < cutlass.Int32(1):
            sub_m_f = clamped * cutlass.Float32(512.0)
            sub_m = f32_to_i32_rni(sub_m_f)
            if sub_m < cutlass.Int32(0):
                sub_m = cutlass.Int32(0)
            mantissa = sub_m
            fp8_exp = cutlass.Int32(0)
            # m == 8 means the value rounded up to 8 * 2^-9 == 2^-6, the smallest
            # normal (exp=1, mant=0 -> 0x08). Clamping m to 7 instead would
            # round inputs in [15 * 2^-10, 15.5 * 2^-10) down to 7 * 2^-9.
            if sub_m > cutlass.Int32(7):
                mantissa = cutlass.Int32(0)
                fp8_exp = cutlass.Int32(1)
        result = (fp8_exp << cutlass.Int32(3)) | mantissa
    return result
@cute.jit
def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
    """Decode an FP8 E4M3 bit pattern (held in an Int32) to Float32.

    Only bits [6:0] are read: bits [2:0] are the mantissa m and bits [6:3]
    are the biased exponent e. The sign bit (bit 7) is not read, so the
    result is never negative.
      - e > 0:            2^(e-7) * (1 + m/8)
      - e == 0, m > 0:    m * 2^-9 = m / 512 (subnormal)
      - e == 0, m == 0:   0.0
    The E4M3 NaN pattern (e=15, m=7) is not special-cased. It decodes as
    2^8 * (1 + 7/8) = 480.0.
    """
    man = bits & cutlass.Int32(7)
    biased_exp = (bits >> cutlass.Int32(3)) & cutlass.Int32(15)
    unbiased = biased_exp - cutlass.Int32(FP8_E4M3_BIAS)

    # Build pow2 = 2^unbiased from 1.0 using exact power-of-two steps.
    # unbiased lies in [-7, 8], so at most one of the two loops changes pow2.
    pow2 = cutlass.Float32(1.0)
    down = unbiased
    for _ in cutlass.range(7, unroll=1):
        if down < cutlass.Int32(0):
            pow2 = pow2 * cutlass.Float32(0.5)
            down = down + cutlass.Int32(1)
    up = unbiased
    for _ in cutlass.range(8, unroll=1):
        if up > cutlass.Int32(0):
            pow2 = pow2 * cutlass.Float32(2.0)
            up = up - cutlass.Int32(1)

    # Both candidate decodings are computed, then the one matching e is kept.
    normal_val = (cutlass.Float32(1.0) + cutlass.Float32(man) / cutlass.Float32(8.0)) * pow2
    subnormal_val = cutlass.Float32(man) / cutlass.Float32(512.0)
    decoded = cutlass.Float32(0.0)
    if biased_exp == cutlass.Int32(0):
        if man > cutlass.Int32(0):
            decoded = subnormal_val
    if biased_exp > cutlass.Int32(0):
        decoded = normal_val
    return decoded
# ── E2M1 FP4 quantization ───────────────────────────────────────────
# E2M1 format (NVFP4 element format):
#   Magnitudes: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
#   Index:      [0, 1,   2,   3,   4,   5,   6,   7  ]
#   A nibble is the 3-bit magnitude index plus 1 sign bit (bit 3).
# Half-step lookup: quantize as round(|x| * 2) -> half_step, then map the
# half_step to an E2M1 index. This mirrors the CUDA reference:
#   int half_step = __float2int_rn(fabsf(scaled) * 2.0f);
#   int idx = half_step_to_e2m1_idx[half_step];
#
# The half_step_to_e2m1_idx LUT (12 entries, half_step 0..11).
# NOTE(review): because |x| is rounded to a half-step before the lookup, this
# is not strict nearest-value rounding on the E2M1 grid. For example,
# half_step 7 (|x| in (3.25, 3.75)) maps to index 6 (4.0), although
# |x| in (3.25, 3.5) is closer to 3.0. Presumably intentional, to match the
# CUDA reference bit-for-bit. Confirm before changing.
# No code visible in this file indexes this list. quantize_e2m1_nibble
# re-expresses the same mapping as threshold comparisons.
_HALF_STEP_TO_E2M1 = [0, 1, 2, 3, 4, 4, 5, 6, 6, 6, 7, 7]
@cute.jit
def quantize_e2m1_nibble(
    val: cutlass.Float32,
    scale: cutlass.Float32,
) -> cutlass.Int32:
    """Quantize one FP32 value to a 4-bit E2M1 nibble.

    Nibble layout: bit 3 = sign, bits [2:0] = E2M1 magnitude index.

    When ``scale`` is not greater than 1e-8 the nibble is 0. Otherwise the
    steps are:
      1. |val / scale| is clipped to 6.0 and doubled.
      2. The result is rounded to an integer half-step with PTX cvt.rni
         (round-to-nearest-even).
      3. The half-step is mapped to an index via the table
         [0, 1, 2, 3, 4, 4, 5, 6, 6, 6, 7, 7].
    Bit 3 is set whenever val / scale is negative, including values whose
    magnitude index rounds to 0. The mapping is intended to match the CUDA
    reference quantizer.
    """
    nibble = cutlass.Int32(0)
    if scale > cutlass.Float32(1e-8):
        ratio = val / scale
        # |ratio| via max(r, -r), then saturate at the largest E2M1 value 6.0.
        magnitude = cute.arch.fmax(ratio, cutlass.Float32(0.0) - ratio)
        magnitude = cute.arch.fmin(magnitude, cutlass.Float32(6.0))
        # Nearest half-step (0..12 after the 6.0 clip), rounded by PTX cvt.rni.
        hs = f32_to_i32_rni(magnitude * cutlass.Float32(2.0))
        # Keep hs inside the 12-entry table domain [0, 11].
        if hs < cutlass.Int32(0):
            hs = cutlass.Int32(0)
        if hs > cutlass.Int32(11):
            hs = cutlass.Int32(11)
        # Table lookup written as a step function, because @cute.jit code here
        # has no random-access arrays. The index rises by one at each threshold
        # half-step 1, 2, 3, 4, 6, 7, 10. So hs 5 -> 4, hs 8/9 -> 6,
        # and hs 11 -> 7.
        code = cutlass.Int32(0)
        if hs >= cutlass.Int32(1):
            code = code + cutlass.Int32(1)
        if hs >= cutlass.Int32(2):
            code = code + cutlass.Int32(1)
        if hs >= cutlass.Int32(3):
            code = code + cutlass.Int32(1)
        if hs >= cutlass.Int32(4):
            code = code + cutlass.Int32(1)
        if hs >= cutlass.Int32(6):
            code = code + cutlass.Int32(1)
        if hs >= cutlass.Int32(7):
            code = code + cutlass.Int32(1)
        if hs >= cutlass.Int32(10):
            code = code + cutlass.Int32(1)
        # Attach the sign bit (bit 3) for negative inputs.
        if ratio >= cutlass.Float32(0.0):
            nibble = code
        if ratio < cutlass.Float32(0.0):
            nibble = code + cutlass.Int32(8)
    return nibble