Files
nvfp4-megamoe-kernel/src/quantize.mojo
biondizzle c2b752c2fe Initial: TileLang NVFP4 mega_moe kernel package
- nvfp4_mega_moe_full: drop-in replacement for deep_gemm.mega.fp8_nvfp4_mega_moe
- transform_nvfp4_weights_for_mega_moe: weight transformation (tested)
- SymmBuffer + get_symm_buffer_for_nvfp4_mega_moe: API-matching stubs
- MEGA_MOE_STATIC=1 support for pipeline testing
- pyproject.toml for pip install
2026-05-13 15:44:51 +00:00

96 lines
2.8 KiB
Mojo

"""
NVFP4 quantization utilities — E2M1 packing and UE4M3 scale handling.
Core math layer for the NVFP4 mega_moe kernel rewrite.
"""
# E2M1 magnitude lookup table (positive values only)
# Index 0-7 maps to: 0, 0.5, 1, 1.5, 2, 3, 4, 6
def e2m1_magnitude(index: Int) -> Float64:
    """Decode a 3-bit E2M1 magnitude code into its numeric value.

    Codes 0 through 7 decode to 0, 0.5, 1, 1.5, 2, 3, 4, 6. The code is read
    as two exponent bits (high) followed by one mantissa bit (low).

    Args:
        index: Magnitude code. Anything outside 0..7 decodes to 0.0.

    Returns:
        The non-negative magnitude for the code.
    """
    if index < 0 or index > 7:
        return 0.0
    var exponent = index >> 1
    var mantissa = Float64(index & 1)
    if exponent == 0:
        # Exponent field of zero: no implicit leading one, so the only
        # values are 0 and 0.5.
        return 0.5 * mantissa
    # Otherwise the value is (1 + mantissa/2) * 2^(exponent - 1).
    return (1.0 + 0.5 * mantissa) * Float64(1 << (exponent - 1))
def quantize_e2m1(value: Float64) -> UInt8:
    """Quantize a Float64 to the nearest E2M1 value and return its 4-bit code.

    The result carries the sign in bit 3 and the magnitude code in bits 0-2;
    the upper four bits of the returned byte are always zero.

    Args:
        value: Number to quantize.

    Returns:
        The sign-and-magnitude nibble. Magnitudes of 6.0 or more, including
        infinity, saturate to the largest code (magnitude 6.0). A value
        exactly halfway between two representable magnitudes maps to the
        smaller one. NaN maps to code 0.
    """
    var sign_bit: UInt8 = 0
    var abs_val = value
    if value < 0.0:
        sign_bit = 1
        abs_val = -value

    # Saturate before searching. Without this guard an infinite input makes
    # every candidate error infinite, so no candidate ever beats index 0 and
    # infinity would quantize to zero instead of the largest magnitude.
    if abs_val >= 6.0:
        return (sign_bit << 3) | UInt8(7)

    # Nearest-magnitude search. The strict comparison keeps the earlier
    # (smaller) candidate when two magnitudes are equally close.
    var best_idx: UInt8 = 0
    var best_err = abs_val  # distance to magnitude 0.0 (index 0)
    for i in range(1, 8):
        var err = abs(abs_val - e2m1_magnitude(i))
        if err < best_err:
            best_err = err
            best_idx = UInt8(i)
    return (sign_bit << 3) | best_idx
def unpack_e2m1(packed: UInt8, idx: Int) -> Float64:
    """Decode one of the two E2M1 values stored in a byte.

    Args:
        packed: Byte holding two 4-bit E2M1 codes.
        idx: 0 selects the low nibble; any other value selects the high nibble.

    Returns:
        The signed value. Bit 3 of the selected nibble is the sign and
        bits 0-2 are the magnitude code.
    """
    var code = packed & 0x0F
    if idx != 0:
        code = (packed >> 4) & 0x0F
    var magnitude = e2m1_magnitude(Int(code & 0x07))
    # Bit 3 set means the value is negative.
    if (code & 0x08) != 0:
        return -magnitude
    return magnitude
def dequantize_nvfp4_weight(
    packed_weight: UInt8,
    block_scale: Float64,
    group_offset: Int,
) -> Float64:
    """Dequantize one NVFP4 weight element from a packed byte.

    Args:
        packed_weight: Byte holding two E2M1 codes.
        block_scale: Scale multiplied into the decoded value. The global
            scale is expected to be folded into it already.
        group_offset: Nibble selector forwarded to `unpack_e2m1`
            (0 = low nibble, otherwise high nibble).

    Returns:
        The signed E2M1 value multiplied by `block_scale`.
    """
    return block_scale * unpack_e2m1(packed_weight, group_offset)
def main() raises:
    """Exercise the E2M1 helpers, printing results and checking round trips.

    Raises:
        Error: If an exactly representable E2M1 value does not survive a
            quantize/unpack round trip.
    """
    # Test E2M1 quantization round-trip. Every value in the list is exactly
    # representable in E2M1, so the round trip must be lossless.
    print("E2M1 quantization test:")
    for val in [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]:
        var code = quantize_e2m1(val)
        var unpacked = unpack_e2m1(code, 0)
        print(" ", val, " -> E2M1 -> ", unpacked)
        if unpacked != val:
            raise Error("E2M1 round-trip mismatch")

    # Test packed byte (two E2M1 values): high nibble holds `hi`, low nibble `lo`.
    print("\nPacked byte test:")
    var lo = 1.5
    var hi = -3.0
    var packed = (quantize_e2m1(hi) << 4) | quantize_e2m1(lo)
    print(" lo=", lo, " hi=", hi, " packed=", packed)
    var lo_back = unpack_e2m1(packed, 0)
    var hi_back = unpack_e2m1(packed, 1)
    print(" unpack lo=", lo_back, " unpack hi=", hi_back)
    if lo_back != lo or hi_back != hi:
        raise Error("E2M1 packed-byte round-trip mismatch")

    # Test NVFP4 dequantization
    print("\nNVFP4 dequantization test:")
    # 0x36: low nibble 0x6 -> magnitude 4.0, high nibble 0x3 -> magnitude 1.5,
    # so with scale 0.5 the expected outputs are lo=2.0 and hi=0.75.
    var packed_w = UInt8(0x36)
    var scale = 0.5
    print(" packed=0x36, scale=0.5, lo=", dequantize_nvfp4_weight(packed_w, scale, 0))
    print(" packed=0x36, scale=0.5, hi=", dequantize_nvfp4_weight(packed_w, scale, 1))