- nvfp4_mega_moe_full: drop-in replacement for deep_gemm.mega.fp8_nvfp4_mega_moe - transform_nvfp4_weights_for_mega_moe: weight transformation (tested) - SymmBuffer + get_symm_buffer_for_nvfp4_mega_moe: API-matching stubs - MEGA_MOE_STATIC=1 support for pipeline testing - pyproject.toml for pip install
96 lines
2.8 KiB
Mojo
96 lines
2.8 KiB
Mojo
"""
|
|
NVFP4 quantization utilities — E2M1 packing and UE4M3 scale handling.
|
|
Core math layer for the NVFP4 mega_moe kernel rewrite.
|
|
"""
|
|
|
|
# E2M1 magnitude decoding (positive values only).
# The 3-bit magnitude index is read as a 2-bit exponent field (bits 1-2) and a
# 1-bit mantissa field (bit 0), so index 0-7 decodes to:
#   0, 0.5, 1, 1.5, 2, 3, 4, 6
def e2m1_magnitude(index: Int) -> Float64:
    """Return the unsigned E2M1 magnitude for a 3-bit index.

    Args:
        index: Magnitude code with the sign bit already stripped.

    Returns:
        The decoded magnitude. Any index outside [0, 7] yields 0.0.
    """
    if index < 0 or index > 7:
        return 0.0

    var exponent = index >> 1
    var half_mantissa = Float64(index & 1) * 0.5

    # Exponent field 0 encodes the two smallest magnitudes: 0.0 and 0.5.
    if exponent == 0:
        return half_mantissa

    # Otherwise the magnitude is (1 + mantissa/2) * 2^(exponent - 1).
    return (1.0 + half_mantissa) * Float64(1 << (exponent - 1))
|
|
|
|
|
|
def quantize_e2m1(value: Float64) -> UInt8:
    """Quantize a float64 value to E2M1 (4-bit), returning the 4-bit nibble with sign.

    The magnitude is rounded to the nearest of {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
    An exact tie between two neighbours goes to the even index (mantissa bit
    0), i.e. round-half-to-even. Magnitudes >= 6.0, including infinity,
    saturate to 6.0.

    Args:
        value: The value to quantize.

    Returns:
        A nibble in the low 4 bits of the result: bit 3 is the sign and
        bits 0-2 are the magnitude index decoded by `e2m1_magnitude`.
        NaN is not representable and maps to 0 (+0.0).
    """
    var sign = 0
    var abs_val = value
    if value < 0.0:
        sign = 1
        abs_val = -value

    # Saturate explicitly. For very large inputs `abs_val - 6.0` rounds back
    # to `abs_val`, so the nearest-match search below could not tell 6.0
    # apart from 0.0 and would fall back to code 0 (this also caught +/-inf).
    if abs_val >= 6.0:
        return UInt8((sign << 3) | 7)

    # Find the best E2M1 match.
    var best_idx = 0
    var best_err = abs_val  # error for idx=0
    for i in range(1, 8):
        var err = abs(abs_val - e2m1_magnitude(i))
        # A strictly smaller error wins. On an exact tie, prefer the even
        # index (mantissa bit 0) so midpoints round half-to-even.
        if err < best_err or (err == best_err and i % 2 == 0):
            best_err = err
            best_idx = i

    return UInt8((sign << 3) | best_idx)
|
|
|
|
|
|
def unpack_e2m1(packed: UInt8, idx: Int) -> Float64:
    """Decode one signed E2M1 value from a byte holding two nibbles.

    Args:
        packed: Byte containing two 4-bit E2M1 codes.
        idx: Nibble selector. 0 selects bits 0-3; any other value selects
            bits 4-7.

    Returns:
        The decoded value. Bit 3 of the selected nibble is the sign and
        bits 0-2 index the magnitude table.
    """
    # Shift the selected nibble down into bits 0-3.
    var shift: UInt8 = 4
    if idx == 0:
        shift = 0
    var nibble = (packed >> shift) & 0x0F

    var magnitude = e2m1_magnitude(Int(nibble & 0x07))
    if (nibble & 0x08) != 0:
        return -magnitude
    return magnitude
|
|
|
|
|
|
def dequantize_nvfp4_weight(
    packed_weight: UInt8,
    block_scale: Float64,
    group_offset: Int,
) -> Float64:
    """Dequantize a single NVFP4 weight element.

    Computes `decoded_e2m1 * block_scale`. Any global scale is assumed to be
    already folded into `block_scale`.

    Args:
        packed_weight: Byte containing two E2M1 codes.
        block_scale: Scale applied to the decoded value.
        group_offset: Nibble selector forwarded to `unpack_e2m1`
            (0 = low nibble, any other value = high nibble).

    Returns:
        The dequantized weight value.
    """
    return block_scale * unpack_e2m1(packed_weight, group_offset)
|
|
|
|
|
|
def main() raises:
    # Test E2M1 quantization round-trip: every listed value is exactly
    # representable, so unpacking the low nibble should reproduce it.
    print("E2M1 quantization test:")
    for val in [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]:
        packed = quantize_e2m1(val)
        unpacked = unpack_e2m1(packed, 0)
        print(" ", val, " -> E2M1 -> ", unpacked)

    # Test packed byte (two E2M1 values): `lo` goes in bits 0-3, `hi` in bits 4-7.
    print("\nPacked byte test:")
    lo = 1.5
    hi = -3.0
    packed = (quantize_e2m1(hi) << 4) | quantize_e2m1(lo)
    print(" lo=", lo, " hi=", hi, " packed=", packed)
    print(" unpack lo=", unpack_e2m1(packed, 0), " unpack hi=", unpack_e2m1(packed, 1))

    # Test NVFP4 dequantization of both nibbles of a fixed byte.
    print("\nNVFP4 dequantization test:")
    # Low nibble 0x6 is magnitude index 6 -> +4.0; high nibble 0x3 is index 3 -> +1.5.
    # With scale 0.5 the expected outputs are lo=2.0 and hi=0.75.
    packed_w = UInt8(0x36)
    scale = 0.5
    print(" packed=0x36, scale=0.5, lo=", dequantize_nvfp4_weight(packed_w, scale, 0))
    print(" packed=0x36, scale=0.5, hi=", dequantize_nvfp4_weight(packed_w, scale, 1))
|