Files
nvfp4-megamoe-kernel/dsv4/kernels/gemm/fp4_quant.py

212 lines
8.1 KiB
Python

"""
NVFP4 quantization primitives for CuTeDSL kernels.
Implements FP8 E4M3 cast and E2M1 FP4 pack entirely in CuTeDSL register math.
No shortcuts — proper bit-level quantization matching the Python/CUDA reference.
FP8 E4M3 format (VERIFIED against PyTorch — bias is 7, NOT 8):
- 1 sign bit, 4 exponent bits, 3 mantissa bits, bias = 7
- Normal: (-1)^s * 2^(e-7) * (1 + m/8), e in [1, 14]
- Subnormal: (-1)^s * 2^(1-7) * (m/8) = m * 2^(-9), e = 0
- Max normal: 2^8 * (1 + 6/8) = 448.0 (exp=15,mant=7 is NaN; exp=15,mant=0-6 are valid)
- Min positive normal: 2^(-6) ≈ 0.015625
- Min positive subnormal: 2^(-9) ≈ 0.001953
NVFP4 format:
- 16-element microblocks
- FP8 E4M3 block scale: amax / 6 (max E2M1 magnitude = 6)
- Per-element E2M1 quantize: nearest of {0, 0.5, 1, 1.5, 2, 3, 4, 6}
- Two 4-bit nibbles packed into one uint8 byte: (odd << 4) | even
CuTeDSL constraints:
- Variables defined before `if` blocks can be reassigned inside and read after.
- Both branches of `if` are compiled; use `cutlass.const_expr` to eliminate dead code.
- `range(unroll=1)` produces runtime loops (not unrolled at trace time).
- No log2, frexp, bit_cast, or reinterpret_cast for scalars.
"""
import cutlass
import cutlass.cute as cute
FP8_E4M3_BIAS = 7  # E4M3 exponent bias: normal value = 2^(e - 7) * (1 + m/8)
def half_step_to_e2m1_idx(hs: cutlass.Int32) -> cutlass.Int32:
    """Translate a half-step count into a 3-bit E2M1 magnitude index.

    A half-step count is round(|x| * 2), so 0..12 covers the magnitudes
    0.0..6.0. The mapping mirrors the CUDA kernel's half_step_to_e4m3() and the
    Python LUT:
        0→0, 1→1, 2→2, 3→3, 4→4, 5→4, 6→5, 7→5, 8→6, 9→6, 10→6, 11→7, 12→7
    Inputs below 4 (including negatives) pass through unchanged, and inputs
    above 12 saturate to 7.

    Written as a descending threshold cascade: each test narrows the answer,
    and the last one that fires wins. Only plain `if` blocks are used, which
    suits CuTeDSL's dynamic control-flow rules.
    """
    idx = cutlass.Int32(7)  # 11, 12 and anything larger
    if hs < cutlass.Int32(11):
        idx = cutlass.Int32(6)  # 8, 9, 10
    if hs < cutlass.Int32(8):
        idx = cutlass.Int32(5)  # 6, 7
    if hs < cutlass.Int32(6):
        idx = cutlass.Int32(4)  # 4, 5
    if hs < cutlass.Int32(4):
        idx = hs  # 0..3 are already the index
    return idx
def fp8_e4m3_from_float32(val: cutlass.Float32) -> cutlass.Int32:
    """Convert a positive Float32 value to FP8 E4M3 bit pattern (returned as Int32).

    Only handles positive values (NVFP4 scale factors are always positive);
    inputs that fail the ``val > 0`` test return 0. Inputs above 448.0 are
    clamped to 448.0 (bits 0x7E), the largest non-NaN E4M3 value.
    Returns the uint8 bit pattern packed into an Int32.

    Algorithm:
      1. Handle zero / non-positive → return 0
      2. Normalize: double/halve val until in [1, 2), tracking floor(log2(val))
      3. FP8 exponent = floor(log2(val)) + bias(7)
      4. Mantissa = round((normalized - 1) * 8); a round-up to 8 carries into
         the exponent
      5. Handle subnormals (exponent < 1): mantissa = round(val * 512), clamped
         to [1, 8]; 8 is encoded as the smallest normal (exp=1, mant=0 = 2^-6)
      6. Pack: (exponent << 3) | mantissa

    NOTE(review): the rounding comments below assume cutlass.Int32(Float32)
    rounds to nearest-even like __float2int_rn. If the DSL lowers it to a
    truncating conversion (as MLIR arith.fptosi does), mantissas round toward
    zero instead — confirm.
    """
    result = cutlass.Int32(0)  # default: zero
    if val > cutlass.Float32(0.0):
        # Clamp to FP8 E4M3 max non-NaN value (exp=15, mant=6 = 448.0)
        clamped = cute.math.fmin(val, cutlass.Float32(448.0))
        # Normalize to [1, 2) range, tracking floor(log2(clamped))
        norm = clamped
        exp_floor = cutlass.Int32(0)
        # Double until >= 1 (for values < 1).
        # 7 doublings reach exp_floor = -7 (fp8_exp = 0); anything that small
        # takes the subnormal path below, which works from `clamped` directly.
        for _ in cutlass.range(7, unroll=1):
            if norm < cutlass.Float32(1.0):
                norm = norm * cutlass.Float32(2.0)
                exp_floor = exp_floor - cutlass.Int32(1)
        # Halve until < 2 (for values >= 2).
        # At most 8 halvings needed (clamped <= 448 < 512 = 2^9).
        for _ in cutlass.range(8, unroll=1):
            if norm >= cutlass.Float32(2.0):
                norm = norm * cutlass.Float32(0.5)
                exp_floor = exp_floor + cutlass.Int32(1)
        # FP8 exponent = floor(log2(val)) + bias
        fp8_exp = exp_floor + cutlass.Int32(FP8_E4M3_BIAS)
        fp8_exp = cute.math.fmin(fp8_exp, cutlass.Int32(15))
        fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0))
        # Mantissa for normal: (norm - 1) * 8, round
        mantissa_f = (norm - cutlass.Float32(1.0)) * cutlass.Float32(8.0)
        mantissa = cutlass.Int32(mantissa_f)  # intended: round-to-nearest-even (see NOTE)
        # Mantissa overflow: if rounded to 8, increment exponent and reset mantissa
        # e.g., 250.0 → norm≈1.953, mantissa=round(7.625)=8 → exp+1, mant=0 → 256.0
        if mantissa >= cutlass.Int32(8):
            mantissa = cutlass.Int32(0)
            fp8_exp = fp8_exp + cutlass.Int32(1)
        mantissa = cute.math.fmin(mantissa, cutlass.Int32(7))
        mantissa = cute.arch.fmax(mantissa, cutlass.Int32(0))
        fp8_exp = cute.math.fmin(fp8_exp, cutlass.Int32(15))
        fp8_exp = cute.arch.fmax(fp8_exp, cutlass.Int32(0))
        # NaN guard: FP8 E4M3 with exp=15 and mant=7 is NaN.
        # Saturate to max non-NaN (exp=15, mant=6 = 448.0).
        if fp8_exp == cutlass.Int32(15):
            if mantissa == cutlass.Int32(7):
                mantissa = cutlass.Int32(6)
        # Subnormal handling: if fp8_exp < 1, value is 2^(1-7) * m/8 = m * 2^(-9)
        # m = round(clamped * 2^9) = round(clamped * 512)
        if fp8_exp < cutlass.Int32(1):
            sub_m_f = clamped * cutlass.Float32(512.0)
            sub_m = cutlass.Int32(sub_m_f)  # intended: round-to-nearest-even
            # Upper bound is 8, not 7: inputs in [7.5, 8) * 2^-9 are nearer to
            # 2^-6 (the smallest normal) than to 7 * 2^-9, so they must round up.
            sub_m = cute.math.fmin(sub_m, cutlass.Int32(8))
            # Lower bound 1: a positive input never encodes to 0 (presumably so
            # a non-zero block scale never becomes zero — confirm vs reference).
            sub_m = cute.arch.fmax(sub_m, cutlass.Int32(1))
            mantissa = sub_m
            fp8_exp = cutlass.Int32(0)
            # Rounding carried out of the subnormal range: encode 8 * 2^-9
            # as the smallest normal, exp=1, mant=0 (= 2^-6).
            if sub_m >= cutlass.Int32(8):
                mantissa = cutlass.Int32(0)
                fp8_exp = cutlass.Int32(1)
        result = (fp8_exp << cutlass.Int32(3)) | mantissa
    return result
def fp8_e4m3_to_float32(bits: cutlass.Int32) -> cutlass.Float32:
    """Convert FP8 E4M3 bit pattern (in Int32) back to Float32.

    Only bits [6:0] are decoded (3 mantissa bits, 4 exponent bits). The sign
    bit and anything above it are masked off, so the result is always >= 0.
      Normal (e >= 1):   val = 2^(e-7) * (1 + m/8)   (e = 15 is finite, up to 448.0)
      Subnormal (e = 0): val = 2^(-6) * (m/8) = m / 512
    NOTE: the NaN code (e=15, m=7) is not special-cased; it decodes to 480.0.
    The encoder fp8_e4m3_from_float32 never produces it.
    """
    mantissa = bits & cutlass.Int32(7)
    exponent = (bits >> cutlass.Int32(3)) & cutlass.Int32(15)
    # Compute 2^(e-7) by iterative doubling/halving from 1.0
    scale = cutlass.Float32(1.0)
    exp_delta = exponent - cutlass.Int32(FP8_E4M3_BIAS)
    # Double for positive delta. E4M3 has no infinities and uses e = 15 for
    # finite values, so delta can reach 15 - 7 = 8: this loop needs 8 steps.
    # (With 7 steps, every e = 15 code — 256..448 — decoded to half its value.)
    d = exp_delta
    for _ in cutlass.range(8, unroll=1):
        if d > cutlass.Int32(0):
            scale = scale * cutlass.Float32(2.0)
            d = d - cutlass.Int32(1)
    # Halve for negative delta (min e=0, delta=-7; e=0 is decoded as subnormal below)
    d = exp_delta
    for _ in cutlass.range(7, unroll=1):
        if d < cutlass.Int32(0):
            scale = scale * cutlass.Float32(0.5)
            d = d + cutlass.Int32(1)
    # Normal value
    normal_val = (cutlass.Float32(1.0) + cutlass.Float32(mantissa) / cutlass.Float32(8.0)) * scale
    # Subnormal value (e=0): val = m * 2^(-9) = m / 512
    subnormal_val = cutlass.Float32(mantissa) / cutlass.Float32(512.0)
    # Select; e=0, m=0 stays 0.0
    result = cutlass.Float32(0.0)
    if exponent > cutlass.Int32(0):
        result = normal_val
    if exponent == cutlass.Int32(0):
        if mantissa > cutlass.Int32(0):
            result = subnormal_val
    return result
def quantize_e2m1_nibble(
    val: cutlass.Float32,
    scale: cutlass.Float32,
) -> cutlass.Int32:
    """Encode one FP32 element as a 4-bit E2M1 code under a block scale.

    The element is divided by ``scale``. Its magnitude is saturated to 6.0,
    turned into a half-step count (0..12) and mapped to a 3-bit magnitude
    index by half_step_to_e2m1_idx. Bit 3 of the returned nibble carries the
    sign: it is set when val/scale < 0, so small negatives give 0b1000.
    When ``scale`` is not above 1e-8, the zero nibble is returned.

    NOTE(review): rounding to half-steps first and then mapping is a two-step
    rounding, so it is not exact nearest rounding for every input. For
    example, 2.6 → hs 5 → 2.0, although 3.0 is nearer (assuming Int32()
    rounds to nearest). It is kept this way to match the reference LUT;
    confirm that this is intended.
    """
    out = cutlass.Int32(0)
    if scale > cutlass.Float32(1e-8):
        ratio = val / scale
        # |ratio| written as max(x, 0 - x), then saturated at E2M1's max (6.0)
        mag = cute.math.fmin(
            cute.arch.fmax(ratio, cutlass.Float32(0.0) - ratio),
            cutlass.Float32(6.0),
        )
        # Half-step count round(|ratio| * 2), kept within [0, 12]
        half_steps = cute.arch.fmax(
            cute.math.fmin(cutlass.Int32(mag * cutlass.Float32(2.0)), cutlass.Int32(12)),
            cutlass.Int32(0),
        )
        code = half_step_to_e2m1_idx(half_steps)
        # Two explicit tests rather than if/else: a NaN ratio satisfies
        # neither, leaving the zero nibble.
        if ratio >= cutlass.Float32(0.0):
            out = code
        if ratio < cutlass.Float32(0.0):
            out = code + cutlass.Int32(8)
    return out