fix: manual FP32→UE4M3 quant in Triton staging kernel
Triton can't cast float8e4nv → uint8 directly. Compute E4M3 bits manually: extract FP32 exponent/mantissa, convert to E4M3 format (4-bit exp + 3-bit mant), handle rounding and overflow, reconstruct dequantized value for FP8 activation quantization.
This commit is contained in:
@@ -292,29 +292,45 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
# NVFP4: UE4M3 activation scales (not UE8M0)
|
||||
# scale_format_ is shared between SFA and SFB in the MMA descriptor,
|
||||
# so both must use the same format. For mxf4nvf4, both are UE4M3.
|
||||
scale = amax / 448.0 # UE4M3 max = 448
|
||||
# Quantize to UE4M3: float → float8_e4m3fn
|
||||
# E4M3: 1 sign + 4 exp + 3 mantissa, max normal = 448
|
||||
# For unsigned (UE4M3), we zero the sign bit
|
||||
scale_f32 = scale.to(tl.float32)
|
||||
scale_bits = scale_f32.to(tl.uint32, bitcast=True)
|
||||
# Extract FP8 E4M3 bits: clamp to [0, 448], convert
|
||||
# Simple approach: store as float32, reinterpret low 8 bits as E4M3
|
||||
# Triton doesn't have float8, so we compute E4M3 manually
|
||||
# Actually, we can use the FP8 e4m3 type in triton
|
||||
scale_e4m3 = scale.to(tl.float8e4nv) # hardware native E4M3
|
||||
scale_u8 = scale_e4m3.to(tl.uint8)
|
||||
# Clear sign bit for UE4M3
|
||||
scale_u8 = scale_u8 & 0x7F
|
||||
# E4M3 format: 1 sign + 4 exp + 3 mantissa, bias=7, max normal=448
|
||||
# UE4M3: sign bit cleared, same representation otherwise
|
||||
scale = amax / 448.0 # Normalize so max activation maps to UE4M3 max
|
||||
scale_bits = scale.to(tl.uint32, bitcast=True)
|
||||
scale_exp = (scale_bits >> 23) & 0xFF
|
||||
scale_mant = scale_bits & 0x7FFFFF
|
||||
|
||||
# Convert FP32 → E4M3 manually
|
||||
# Normal: exp in [1, 254], mantissa has 23 bits → take top 3
|
||||
# For E4M3: exp_bits = FP32_exp - 127 + 7 = FP32_exp - 120, clamped to [1, 15]
|
||||
# mant_bits = FP32_mant >> 20 (top 3 of 23)
|
||||
# Round: if any of the dropped mantissa bits are set, add 1 to mant
|
||||
e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7
|
||||
# Handle subnormal: if FP32 exp < 121, E4M3 exp would be 0 (subnormal)
|
||||
e4m3_exp = tl.maximum(e4m3_exp, 0)
|
||||
e4m3_exp = tl.minimum(e4m3_exp, 15)
|
||||
# Mantissa: top 3 bits of FP32 mantissa, with rounding
|
||||
e4m3_mant = scale_mant >> 20 # top 3 bits
|
||||
round_bit = (scale_mant >> 19) & 1 # 4th bit for rounding
|
||||
e4m3_mant = e4m3_mant + round_bit
|
||||
# If mantissa overflowed (8), carry into exponent
|
||||
overflow = e4m3_mant >= 8
|
||||
e4m3_mant = tl.where(overflow, 0, e4m3_mant)
|
||||
e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)
|
||||
e4m3_exp = tl.minimum(e4m3_exp, 15)
|
||||
# Pack: UE4M3 = (exp << 3) | mant (no sign bit)
|
||||
scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant # uint8 representation
|
||||
|
||||
# Reconstruct the dequantized scale for FP8 activation quantization
|
||||
# UE4M3 dequant: if exp == 0: value = (mant/8) * 2^(1-7) = (mant/8) * 2^-6
|
||||
# else: value = (1 + mant/8) * 2^(exp-7)
|
||||
e4m3_value = tl.where(
|
||||
e4m3_exp == 0,
|
||||
(e4m3_mant.to(tl.float32) / 8.0) * (2.0 ** -6),
|
||||
(1.0 + e4m3_mant.to(tl.float32) / 8.0) * (2.0 ** (e4m3_exp.to(tl.float32) - 7.0))
|
||||
)
|
||||
|
||||
# Dequantize the rounded scale for FP8 activation quantization
|
||||
# UE4M3 value = (1 + m/8) * 2^(e-7) for normal, or (m/8) * 2^(1-7) for subnormal
|
||||
# For simplicity, reconstruct from the rounded FP8 value
|
||||
rounded_scale = scale_e4m3.to(tl.float32) # dequantize UE4M3 → FP32
|
||||
# But we need the FP8 activation, quantized with this scale
|
||||
# Recompute: activation_fp8 = activation / rounded_scale
|
||||
hidden_groups_f = tl.reshape(hidden, [num_groups, GROUP_K])
|
||||
scaled = hidden_groups_f * (1.0 / tl.maximum(rounded_scale, 1e-6))[:, None]
|
||||
scaled = hidden_groups_f * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
|
||||
scaled = tl.reshape(scaled, [BLOCK_K])
|
||||
fp8 = scaled.to(tl.float8e4nv)
|
||||
tl.store(
|
||||
@@ -325,7 +341,7 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
|
||||
# Pack 4 UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements)
|
||||
scale_offsets = tl.arange(0, num_groups)
|
||||
packed_scale = tl.sum(scale_u8.to(tl.int32) << (scale_offsets * 8), axis=0).to(tl.int32)
|
||||
packed_scale = tl.sum(scale_e4m3_bits.to(tl.int32) << (scale_offsets * 8), axis=0).to(tl.int32)
|
||||
tl.store(
|
||||
x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k,
|
||||
packed_scale,
|
||||
|
||||
Reference in New Issue
Block a user