212 lines
8.1 KiB
Python
212 lines
8.1 KiB
Python
"""
|
|
NVFP4 quantization primitives for CuTeDSL kernels.
|
|
|
|
Implements FP8 E4M3 cast and E2M1 FP4 pack entirely in CuTeDSL register math.
|
|
No shortcuts — proper bit-level quantization matching the Python/CUDA reference.
|
|
|
|
FP8 E4M3 format (VERIFIED against PyTorch — bias is 7, NOT 8):
|
|
- 1 sign bit, 4 exponent bits, 3 mantissa bits, bias = 7
|
|
- Normal: (-1)^s * 2^(e-7) * (1 + m/8), e in [1, 15] (e = 15 with m = 7 is NaN)
|
|
- Subnormal: (-1)^s * 2^(1-7) * (m/8) = m * 2^(-9), e = 0
|
|
- Max normal: 2^8 * (1 + 6/8) = 448.0 (exp=15,mant=7 is NaN; exp=15,mant=0-6 are valid)
|
|
- Min positive normal: 2^(-6) = 0.015625
|
|
- Min positive subnormal: 2^(-9) ≈ 0.001953
|
|
|
|
NVFP4 format:
|
|
- 16-element microblocks
|
|
- FP8 E4M3 block scale: amax / 6 (max E2M1 magnitude = 6)
|
|
- Per-element E2M1 quantize: nearest of {0, 0.5, 1, 1.5, 2, 3, 4, 6}
|
|
- Two 4-bit nibbles packed into one uint8 byte: (odd << 4) | even
|
|
|
|
CuTeDSL constraints:
|
|
- Variables defined before `if` blocks can be reassigned inside and read after.
|
|
- Both branches of `if` are compiled; use `cutlass.const_expr` to eliminate dead code.
|
|
- `range(unroll=1)` produces runtime loops (not unrolled at trace time).
|
|
- No log2, frexp, bit_cast, or reinterpret_cast for scalars.
|
|
"""
|
|
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
|
|
# Exponent bias of FP8 E4M3: stored exponent = floor(log2(value)) + 7.
# Used by both the encoder and the decoder in this module.
FP8_E4M3_BIAS = 7
|
|
|
|
|
|
def half_step_to_e2m1_idx(hs: cutlass.Int32) -> cutlass.Int32:
    """Translate a half-step count into the 3-bit E2M1 magnitude index.

    Lookup table (half-step → index):
        0→0, 1→1, 2→2, 3→3, 4→4, 5→4, 6→5, 7→5, 8→6, 9→6, 10→6, 11→7, 12→7

    Inputs outside 0..12 are not rejected: anything below 4 is returned
    unchanged and anything at or above 11 maps to 7.

    Args:
        hs: Half-step count (callers in this module pass a value in 0..12).

    Returns:
        The E2M1 index for ``hs``.
    """
    # Descending threshold cascade: each test that still holds overwrites the
    # previous bucket, so the last matching (smallest) threshold wins.
    idx = cutlass.Int32(7)  # hs >= 11
    if hs < cutlass.Int32(11):
        idx = cutlass.Int32(6)  # 8, 9, 10
    if hs < cutlass.Int32(8):
        idx = cutlass.Int32(5)  # 6, 7
    if hs < cutlass.Int32(6):
        idx = cutlass.Int32(4)  # 4, 5
    if hs < cutlass.Int32(4):
        idx = hs  # 0..3 map to themselves
    return idx
|
|
|
|
|
|
def _round_half_even_to_int32(x: cutlass.Float32) -> cutlass.Int32:
    """Round ``x`` to the nearest integer, ties to even.

    NOTE(review): whether ``cutlass.Int32(float)`` truncates or rounds is not
    established anywhere in this file. The corrections below produce
    round-half-to-even in either case, so the result does not depend on the
    conversion's rounding mode.
    """
    nearest = cutlass.Int32(x)
    err = x - cutlass.Float32(nearest)

    # Pull a truncated (or otherwise off-by-one) conversion to the nearest integer.
    if err > cutlass.Float32(0.5):
        nearest = nearest + cutlass.Int32(1)
    if err < cutlass.Float32(-0.5):
        nearest = nearest - cutlass.Int32(1)

    # Exact ties: move off an odd candidate onto its even neighbour.
    err = x - cutlass.Float32(nearest)
    if (nearest & cutlass.Int32(1)) != cutlass.Int32(0):
        if err == cutlass.Float32(0.5):
            nearest = nearest + cutlass.Int32(1)
        if err == cutlass.Float32(-0.5):
            nearest = nearest - cutlass.Int32(1)

    return nearest


def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
    """Encode a positive Float32 as an FP8 E4M3 bit pattern (held in an Int32).

    Only positive inputs are encoded; zero, negative and NaN inputs fail the
    ``val > 0`` test and return 0. The sign bit is never set.

    Algorithm:
        1. Saturate the input at 448.0 (exp=15, mant=6, the largest non-NaN code).
        2. Normalize into [1, 2) by repeated doubling/halving, tracking
           floor(log2(val)).
        3. Stored exponent = floor(log2(val)) + FP8_E4M3_BIAS.
        4. Mantissa = round((normalized - 1) * 8), ties to even; a mantissa of 8
           carries into the exponent.
        5. Stored exponent 0 selects the subnormal encoding m * 2^(-9); a
           subnormal mantissa of 8 carries into the minimum normal (2^-6).
        6. Pack as (exponent << 3) | mantissa.

    Args:
        val: Value to encode.

    Returns:
        The 7-bit magnitude pattern (exponent in bits [6:3], mantissa in [2:0]).
        Positive inputs never encode to 0: the subnormal mantissa is floored at
        1, as in the original implementation (presumably so a block scale can
        never become zero - TODO confirm with callers).
    """
    result = cutlass.Int32(0)  # default: zero

    if val > cutlass.Float32(0.0):
        # Saturate at 448.0. Written as "start at the cap, take val if it is
        # smaller" so only a compare and an assignment are needed.
        clamped = cutlass.Float32(448.0)
        if val < cutlass.Float32(448.0):
            clamped = val

        # Normalize to [1, 2), tracking floor(log2(clamped)).
        norm = clamped
        exp_floor = cutlass.Int32(0)

        # Double until >= 1. Seven doublings reach the smallest normal (2^-6)
        # and leave anything smaller at exp_floor = -7, i.e. stored exponent 0,
        # which routes it to the subnormal path below.
        for _ in cutlass.range(7, unroll=1):
            if norm < cutlass.Float32(1.0):
                norm = norm * cutlass.Float32(2.0)
                exp_floor = exp_floor - cutlass.Int32(1)

        # Halve until < 2. Eight halvings cover the saturated maximum
        # (448 / 2^8 = 1.75).
        for _ in cutlass.range(8, unroll=1):
            if norm >= cutlass.Float32(2.0):
                norm = norm * cutlass.Float32(0.5)
                exp_floor = exp_floor + cutlass.Int32(1)

        # Stored exponent; exp_floor is in [-7, 8], so this is in [0, 15].
        fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS)

        # Normal-path mantissa: (norm - 1) * 8 rounded to nearest, ties to even.
        mantissa = _round_half_even_to_int32(
            (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0)
        )

        # Mantissa overflow: rounding up to 8 bumps the exponent.
        # e.g. 250.0 -> norm ~= 1.953, round(7.625) = 8 -> exp + 1, mant = 0 -> 256.0
        if mantissa >= cutlass.Int32(8):
            mantissa = cutlass.Int32(0)
            fp8_exp = fp8_exp + cutlass.Int32(1)

        # Integer clamps use compares rather than the float min/max intrinsics so
        # the values stay Int32 for the shift/or below.
        # NOTE(review): confirm against the installed DSL that fmin/fmax are
        # float-only; the compares are valid either way.
        if mantissa > cutlass.Int32(7):
            mantissa = cutlass.Int32(7)
        if mantissa < cutlass.Int32(0):
            # Happens only when norm < 1 (subnormal input); overwritten below.
            mantissa = cutlass.Int32(0)
        if fp8_exp > cutlass.Int32(15):
            fp8_exp = cutlass.Int32(15)
        if fp8_exp < cutlass.Int32(0):
            fp8_exp = cutlass.Int32(0)

        # NaN guard: exp=15, mant=7 is the NaN code. Saturate to 448.0 instead.
        if fp8_exp == cutlass.Int32(15):
            if mantissa == cutlass.Int32(7):
                mantissa = cutlass.Int32(6)

        # Subnormal path: value = m * 2^(-9), so m = round(clamped * 512).
        if fp8_exp < cutlass.Int32(1):
            sub_m = _round_half_even_to_int32(clamped * cutlass.Float32(512.0))
            if sub_m < cutlass.Int32(1):
                sub_m = cutlass.Int32(1)  # never flush a positive input to zero
            if sub_m > cutlass.Int32(8):
                sub_m = cutlass.Int32(8)
            # sub_m == 8 is 8 * 2^(-9) = 2^(-6), the minimum normal. Splitting
            # the value into exponent/mantissa fields makes that carry explicit
            # (exp=1, mant=0) instead of clamping down to 7 * 2^(-9).
            fp8_exp = sub_m >> cutlass.Int32(3)
            mantissa = sub_m & cutlass.Int32(7)

        result = (fp8_exp << cutlass.Int32(3)) | mantissa

    return result
|
|
|
|
|
|
def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
    """Decode an FP8 E4M3 magnitude bit pattern (held in an Int32) to Float32.

    Normal    (e in 1..15): val = 2^(e-7) * (1 + m/8)
    Subnormal (e = 0):      val = 2^(1-7) * (m/8) = m / 512

    Only bits [6:0] are read. The sign bit is ignored, so the result is always
    non-negative. The NaN code (e=15, m=7) is not special-cased; it decodes
    through the normal formula.

    Args:
        bits: Bit pattern with the exponent in bits [6:3] and mantissa in [2:0].

    Returns:
        The decoded magnitude; 0.0 when both exponent and mantissa are zero.
    """
    mantissa = bits & cutlass.Int32(7)
    exponent = (bits >> cutlass.Int32(3)) & cutlass.Int32(15)

    # Build 2^(e-7) by iterative doubling/halving from 1.0.
    scale = cutlass.Float32(1.0)
    exp_delta = exponent - cutlass.Int32(FP8_E4M3_BIAS)

    # Double for positive delta. The largest stored exponent is 15, i.e.
    # delta = 8, so eight iterations are required; seven would decode every
    # e = 15 code (256..448) to half its value.
    d = exp_delta
    for _ in cutlass.range(8, unroll=1):
        if d > cutlass.Int32(0):
            scale = scale * cutlass.Float32(2.0)
            d = d - cutlass.Int32(1)

    # Halve for negative delta (smallest is e = 0, delta = -7).
    d = exp_delta
    for _ in cutlass.range(7, unroll=1):
        if d < cutlass.Int32(0):
            scale = scale * cutlass.Float32(0.5)
            d = d + cutlass.Int32(1)

    # Normal value: implicit leading 1 plus m/8, times the power of two.
    normal_val = (
        cutlass.Float32(1.0) + cutlass.Float32(mantissa) / cutlass.Float32(8.0)
    ) * scale

    # Subnormal value (e = 0): m * 2^(-9) = m / 512, independent of `scale`.
    subnormal_val = cutlass.Float32(mantissa) / cutlass.Float32(512.0)

    # Select; e = 0 with m = 0 keeps the 0.0 default.
    result = cutlass.Float32(0.0)
    if exponent > cutlass.Int32(0):
        result = normal_val
    if exponent == cutlass.Int32(0):
        if mantissa > cutlass.Int32(0):
            result = subnormal_val

    return result
|
|
|
|
|
|
def quantize_e2m1_nibble(
    val: cutlass.Float32,
    scale: cutlass.Float32,
) -> cutlass.Int32:
    """Quantize a single FP32 value to a 4-bit E2M1 nibble.

    The value is divided by ``scale``, its magnitude is saturated at 6.0 and
    rounded to the nearest half step (ties to even), and the half-step count is
    mapped to an E2M1 index via ``half_step_to_e2m1_idx``.

    Args:
        val: Value to quantize.
        scale: Block scale. If it is not greater than 1e-8 the nibble is 0.

    Returns:
        Nibble with bit 3 = sign and bits [2:0] = E2M1 index. A negative input
        whose magnitude rounds to index 0 yields 8 (sign bit set, zero
        magnitude). A NaN quotient fails both sign tests and yields 0.
    """
    nibble = cutlass.Int32(0)

    if scale > cutlass.Float32(1e-8):
        scaled = val / scale

        # |scaled| via compare-and-assign.
        magnitude = scaled
        if scaled < cutlass.Float32(0.0):
            magnitude = cutlass.Float32(0.0) - scaled

        # Saturate at 6.0 (largest E2M1 magnitude). Starting from the cap means
        # a NaN magnitude also lands on 6.0 rather than propagating into the
        # integer conversion below.
        abs_scaled = cutlass.Float32(6.0)
        if magnitude < cutlass.Float32(6.0):
            abs_scaled = magnitude

        # half_step = round(|scaled| * 2), ties to even.
        # NOTE(review): whether cutlass.Int32(float) truncates or rounds is not
        # established in this file; the corrections below give ties-to-even in
        # either case.
        doubled = abs_scaled * cutlass.Float32(2.0)
        hs = cutlass.Int32(doubled)
        err = doubled - cutlass.Float32(hs)
        if err > cutlass.Float32(0.5):
            hs = hs + cutlass.Int32(1)
        if err < cutlass.Float32(-0.5):
            hs = hs - cutlass.Int32(1)
        err = doubled - cutlass.Float32(hs)
        if (hs & cutlass.Int32(1)) != cutlass.Int32(0):
            if err == cutlass.Float32(0.5):
                hs = hs + cutlass.Int32(1)
            if err == cutlass.Float32(-0.5):
                hs = hs - cutlass.Int32(1)

        # Defensive integer clamp to 0..12; compares keep hs an Int32 (the float
        # min/max intrinsics would not - NOTE(review): confirm against the DSL).
        if hs > cutlass.Int32(12):
            hs = cutlass.Int32(12)
        if hs < cutlass.Int32(0):
            hs = cutlass.Int32(0)

        idx = half_step_to_e2m1_idx(hs)

        # Two separate tests (not if/else) so a NaN quotient leaves nibble at 0.
        if scaled < cutlass.Float32(0.0):
            nibble = idx + cutlass.Int32(8)
        if scaled >= cutlass.Float32(0.0):
            nibble = idx

    return nibble
|