""" NVFP4 quantization utilities — E2M1 packing and UE4M3 scale handling. Core math layer for the NVFP4 mega_moe kernel rewrite. """ # E2M1 magnitude lookup table (positive values only) # Index 0-7 maps to: 0, 0.5, 1, 1.5, 2, 3, 4, 6 def e2m1_magnitude(index: Int) -> Float64: if index == 0: return 0.0 if index == 1: return 0.5 if index == 2: return 1.0 if index == 3: return 1.5 if index == 4: return 2.0 if index == 5: return 3.0 if index == 6: return 4.0 if index == 7: return 6.0 return 0.0 def quantize_e2m1(value: Float64) -> UInt8: """Quantize a float64 value to E2M1 (4-bit), returning the 4-bit nibble with sign.""" var sign = 0 var abs_val = value if value < 0.0: sign = 1 abs_val = -value # Find best E2M1 match var best_idx = 0 var best_err = abs_val # error for idx=0 for i in range(1, 8): mag = e2m1_magnitude(i) err = abs(abs_val - mag) if err < best_err: best_err = err best_idx = i return (sign << 3) | best_idx def unpack_e2m1(packed: UInt8, idx: Int) -> Float64: """Unpack one E2M1 value from a packed byte. idx=0 -> low nibble, idx=1 -> high nibble. """ nibble: UInt8 if idx == 0: nibble = packed & 0x0F else: nibble = (packed >> 4) & 0x0F # keep sign bit sign = (nibble >> 3) & 1 mag_idx = nibble & 0x07 magnitude = e2m1_magnitude(Int(mag_idx)) if sign: return -magnitude return magnitude def dequantize_nvfp4_weight( packed_weight: UInt8, block_scale: Float64, group_offset: Int, ) -> Float64: """Dequantize a single NVFP4 weight element. weight = E2M1_magnitude * block_scale (global_scale is already folded into block_scale) """ e2m1_value = unpack_e2m1(packed_weight, group_offset) return e2m1_value * block_scale def main() raises: # Test E2M1 quantization round-trip print("E2M1 quantization test:") for val in [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]: packed = quantize_e2m1(val) unpacked = unpack_e2m1(packed, 0) print(" ", val, " -> E2M1 -> ", unpacked) # Test packed byte (two E2M1 values) print("\nPacked byte test:") lo = 1.5 hi = -3.0 packed = (quantize_e2m1(hi) << 4) | quantize_e2m1(lo) print(" lo=", lo, " hi=", hi, " packed=", packed) print(" unpack lo=", unpack_e2m1(packed, 0), " unpack hi=", unpack_e2m1(packed, 1)) # Test NVFP4 dequantization print("\nNVFP4 dequantization test:") packed_w = UInt8(0x36) # low=6.0, high=3.0 scale = 0.5 print(" packed=0x36, scale=0.5, lo=", dequantize_nvfp4_weight(packed_w, scale, 0)) print(" packed=0x36, scale=0.5, hi=", dequantize_nvfp4_weight(packed_w, scale, 1))