fix: manual FP32→UE4M3 quant in Triton staging kernel

Triton can't cast float8e4nv → uint8 directly. Compute E4M3 bits manually:
extract FP32 exponent/mantissa, convert to E4M3 format (4-bit exp + 3-bit mant),
handle rounding and overflow, reconstruct dequantized value for FP8 activation quantization.
This commit is contained in:
2026-05-11 16:38:49 +00:00
parent 436109081c
commit c4891e9ee2

View File

@@ -292,29 +292,45 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
# NVFP4: UE4M3 activation scales (not UE8M0)
# scale_format_ is shared between SFA and SFB in the MMA descriptor,
# so both must use the same format. For mxf4nvf4, both are UE4M3.
scale = amax / 448.0 # UE4M3 max = 448
# Quantize to UE4M3: float → float8_e4m3fn
# E4M3: 1 sign + 4 exp + 3 mantissa, max normal = 448
# For unsigned (UE4M3), we zero the sign bit
scale_f32 = scale.to(tl.float32)
scale_bits = scale_f32.to(tl.uint32, bitcast=True)
# Extract FP8 E4M3 bits: clamp to [0, 448], convert
# Simple approach: store as float32, reinterpret low 8 bits as E4M3
# Triton doesn't have float8, so we compute E4M3 manually
# Actually, we can use the FP8 e4m3 type in triton
scale_e4m3 = scale.to(tl.float8e4nv) # hardware native E4M3
scale_u8 = scale_e4m3.to(tl.uint8)
# Clear sign bit for UE4M3
scale_u8 = scale_u8 & 0x7F
# E4M3 format: 1 sign + 4 exp + 3 mantissa, bias=7, max normal=448
# UE4M3: sign bit cleared, same representation otherwise
scale = amax / 448.0 # Normalize so max activation maps to UE4M3 max
scale_bits = scale.to(tl.uint32, bitcast=True)
scale_exp = (scale_bits >> 23) & 0xFF
scale_mant = scale_bits & 0x7FFFFF
# Convert FP32 → E4M3 manually
# Normal: exp in [1, 254], mantissa has 23 bits → take top 3
# For E4M3: exp_bits = FP32_exp - 127 + 7 = FP32_exp - 120, clamped to [1, 15]
# mant_bits = FP32_mant >> 20 (top 3 of 23)
# Round: if any of the dropped mantissa bits are set, add 1 to mant
e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7
# Handle subnormal: if FP32 exp < 121, E4M3 exp would be 0 (subnormal)
e4m3_exp = tl.maximum(e4m3_exp, 0)
e4m3_exp = tl.minimum(e4m3_exp, 15)
# Mantissa: top 3 bits of FP32 mantissa, with rounding
e4m3_mant = scale_mant >> 20 # top 3 bits
round_bit = (scale_mant >> 19) & 1 # 4th bit for rounding
e4m3_mant = e4m3_mant + round_bit
# If mantissa overflowed (8), carry into exponent
overflow = e4m3_mant >= 8
e4m3_mant = tl.where(overflow, 0, e4m3_mant)
e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)
e4m3_exp = tl.minimum(e4m3_exp, 15)
# Pack: UE4M3 = (exp << 3) | mant (no sign bit)
scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant # uint8 representation
# Reconstruct the dequantized scale for FP8 activation quantization
# UE4M3 dequant: if exp == 0: value = (mant/8) * 2^(1-7) = (mant/8) * 2^-6
# else: value = (1 + mant/8) * 2^(exp-7)
e4m3_value = tl.where(
e4m3_exp == 0,
(e4m3_mant.to(tl.float32) / 8.0) * (2.0 ** -6),
(1.0 + e4m3_mant.to(tl.float32) / 8.0) * (2.0 ** (e4m3_exp.to(tl.float32) - 7.0))
)
# Dequantize the rounded scale for FP8 activation quantization
# UE4M3 value = (1 + m/8) * 2^(e-7) for normal, or (m/8) * 2^(1-7) for subnormal
# For simplicity, reconstruct from the rounded FP8 value
rounded_scale = scale_e4m3.to(tl.float32) # dequantize UE4M3 → FP32
# But we need the FP8 activation, quantized with this scale
# Recompute: activation_fp8 = activation / rounded_scale
hidden_groups_f = tl.reshape(hidden, [num_groups, GROUP_K])
scaled = hidden_groups_f * (1.0 / tl.maximum(rounded_scale, 1e-6))[:, None]
scaled = hidden_groups_f * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
scaled = tl.reshape(scaled, [BLOCK_K])
fp8 = scaled.to(tl.float8e4nv)
tl.store(
@@ -325,7 +341,7 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
# Pack 4 UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements)
scale_offsets = tl.arange(0, num_groups)
packed_scale = tl.sum(scale_u8.to(tl.int32) << (scale_offsets * 8), axis=0).to(tl.int32)
packed_scale = tl.sum(scale_e4m3_bits.to(tl.int32) << (scale_offsets * 8), axis=0).to(tl.int32)
tl.store(
x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k,
packed_scale,