diff --git a/patches/deepseek_v4.py b/patches/deepseek_v4.py index 57d0450..6997dfb 100644 --- a/patches/deepseek_v4.py +++ b/patches/deepseek_v4.py @@ -292,29 +292,45 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( # NVFP4: UE4M3 activation scales (not UE8M0) # scale_format_ is shared between SFA and SFB in the MMA descriptor, # so both must use the same format. For mxf4nvf4, both are UE4M3. - scale = amax / 448.0 # UE4M3 max = 448 - # Quantize to UE4M3: float → float8_e4m3fn - # E4M3: 1 sign + 4 exp + 3 mantissa, max normal = 448 - # For unsigned (UE4M3), we zero the sign bit - scale_f32 = scale.to(tl.float32) - scale_bits = scale_f32.to(tl.uint32, bitcast=True) - # Extract FP8 E4M3 bits: clamp to [0, 448], convert - # Simple approach: store as float32, reinterpret low 8 bits as E4M3 - # Triton doesn't have float8, so we compute E4M3 manually - # Actually, we can use the FP8 e4m3 type in triton - scale_e4m3 = scale.to(tl.float8e4nv) # hardware native E4M3 - scale_u8 = scale_e4m3.to(tl.uint8) - # Clear sign bit for UE4M3 - scale_u8 = scale_u8 & 0x7F + # E4M3 format: 1 sign + 4 exp + 3 mantissa, bias=7, max normal=448 + # UE4M3: sign bit cleared, same representation otherwise + scale = amax / 448.0 # Normalize so max activation maps to UE4M3 max + scale_bits = scale.to(tl.uint32, bitcast=True) + scale_exp = (scale_bits >> 23) & 0xFF + scale_mant = scale_bits & 0x7FFFFF + + # Convert FP32 → E4M3 manually + # Normal: exp in [1, 254], mantissa has 23 bits → take top 3 + # For E4M3: exp_bits = FP32_exp - 127 + 7 = FP32_exp - 120, clamped to [1, 15] + # mant_bits = FP32_mant >> 20 (top 3 of 23) + # Round: if any of the dropped mantissa bits are set, add 1 to mant + e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7 + # Handle subnormal: if FP32 exp < 121, E4M3 exp would be 0 (subnormal) + e4m3_exp = tl.maximum(e4m3_exp, 0) + e4m3_exp = tl.minimum(e4m3_exp, 15) + # Mantissa: top 3 bits of FP32 mantissa, with rounding + e4m3_mant = scale_mant >> 20 # top 3 bits + round_bit = (scale_mant >> 19) & 1 # 4th bit for rounding + e4m3_mant = e4m3_mant + round_bit + # If mantissa overflowed (8), carry into exponent + overflow = e4m3_mant >= 8 + e4m3_mant = tl.where(overflow, 0, e4m3_mant) + e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp) + e4m3_exp = tl.minimum(e4m3_exp, 15) + # Pack: UE4M3 = (exp << 3) | mant (no sign bit) + scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant # uint8 representation + + # Reconstruct the dequantized scale for FP8 activation quantization + # UE4M3 dequant: if exp == 0: value = (mant/8) * 2^(1-7) = (mant/8) * 2^-6 + # else: value = (1 + mant/8) * 2^(exp-7) + e4m3_value = tl.where( + e4m3_exp == 0, + (e4m3_mant.to(tl.float32) / 8.0) * (2.0 ** -6), + (1.0 + e4m3_mant.to(tl.float32) / 8.0) * (2.0 ** (e4m3_exp.to(tl.float32) - 7.0)) + ) - # Dequantize the rounded scale for FP8 activation quantization - # UE4M3 value = (1 + m/8) * 2^(e-7) for normal, or (m/8) * 2^(1-7) for subnormal - # For simplicity, reconstruct from the rounded FP8 value - rounded_scale = scale_e4m3.to(tl.float32) # dequantize UE4M3 → FP32 - # But we need the FP8 activation, quantized with this scale - # Recompute: activation_fp8 = activation / rounded_scale hidden_groups_f = tl.reshape(hidden, [num_groups, GROUP_K]) - scaled = hidden_groups_f * (1.0 / tl.maximum(rounded_scale, 1e-6))[:, None] + scaled = hidden_groups_f * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None] scaled = tl.reshape(scaled, [BLOCK_K]) fp8 = scaled.to(tl.float8e4nv) tl.store( @@ -325,7 +341,7 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( # Pack 4 UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements) scale_offsets = tl.arange(0, num_groups) - packed_scale = tl.sum(scale_u8.to(tl.int32) << (scale_offsets * 8), axis=0).to(tl.int32) + packed_scale = tl.sum(scale_e4m3_bits.to(tl.int32) << (scale_offsets * 8), axis=0).to(tl.int32) tl.store( x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k, packed_scale,