From 5faf9916eb4b7e2068b3edb2f3b480bf2a09fd80 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 16:12:36 +0000 Subject: [PATCH] fix: UE4M3 activation scales + group_size=16 for NVFP4 mega_moe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The mxf4nvf4 MMA instruction shares scale_format_ between SFA and SFB. For NVFP4 (UE4M3), both activation and weight scales must be UE4M3. Changes to _deepseek_v4_stage_mega_moe_inputs_kernel and its launcher _stage_deepseek_v4_mega_moe_inputs: - GROUP_K=16 (was 32) — NVFP4 scale_vec::4X has group_size=16 - Scale quantization: float → float8_e4m3fn (UE4M3) instead of UE8M0 exponent extraction (>> 23). Pack 4 UE4M3 bytes per int32. - FP8 activation quantized against UE4M3 rounded scale Also updated class docstring (was stale MXFP4 conversion description). --- patches/deepseek_v4.py | 50 +++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/patches/deepseek_v4.py b/patches/deepseek_v4.py index 1a272c2..57d0450 100644 --- a/patches/deepseek_v4.py +++ b/patches/deepseek_v4.py @@ -289,16 +289,32 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( amax = tl.max(hidden_groups, axis=1) amax = tl.maximum(amax, 1.0e-4) - scale = amax / 448.0 - scale_bits = scale.to(tl.uint32, bitcast=True) - scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to( - tl.uint32 - ) - scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254) - rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True) + # NVFP4: UE4M3 activation scales (not UE8M0) + # scale_format_ is shared between SFA and SFB in the MMA descriptor, + # so both must use the same format. For mxf4nvf4, both are UE4M3. 
+ scale = tl.maximum(amax / 448.0, 0.001953125) # UE4M3 max = 448; floor at 2**-9 (smallest UE4M3 subnormal) so the stored scale never rounds to 0 + # Quantize to UE4M3: float → float8_e4m3fn + # E4M3: 1 sign + 4 exp + 3 mantissa, max normal = 448 + # For unsigned (UE4M3), we zero the sign bit + scale_f32 = scale.to(tl.float32) + scale_bits = scale_f32.to(tl.uint32, bitcast=True) + # Round the scale to E4M3 with Triton's native float8e4nv cast, then + # reinterpret the resulting byte as uint8 for packing. bitcast=True is + # required: a plain .to(tl.uint8) is a numeric cast, not a bit + # reinterpretation, so it would not yield the E4M3 encoding of the scale. + scale_e4m3 = scale.to(tl.float8e4nv) # hardware native E4M3 + scale_u8 = scale_e4m3.to(tl.uint8, bitcast=True) + # Clear sign bit for UE4M3 + scale_u8 = scale_u8 & 0x7F - hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) - scaled = hidden_groups * (1.0 / rounded_scale)[:, None] + # Dequantize the rounded scale for FP8 activation quantization + # UE4M3 value = (1 + m/8) * 2^(e-7) for normal, or (m/8) * 2^(1-7) for subnormal + # The activation must be divided by the same rounded value that is stored in x_sf + rounded_scale = scale_e4m3.to(tl.float32) # dequantize UE4M3 → FP32 + # But we need the FP8 activation, quantized with this scale + # Recompute: activation_fp8 = activation / rounded_scale + hidden_groups_f = tl.reshape(hidden, [num_groups, GROUP_K]) + scaled = hidden_groups_f * (1.0 / tl.maximum(rounded_scale, 1e-6))[:, None] scaled = tl.reshape(scaled, [BLOCK_K]) fp8 = scaled.to(tl.float8e4nv) tl.store( @@ -307,8 +323,9 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( mask=k_mask, ) + # Pack one UE4M3 byte per group into an int32 (NVFP4: group_size=16); only 4 bytes fit, so this assumes num_groups == 4, i.e. BLOCK_K == 64 - TODO confirm block_k at the launch site scale_offsets = tl.arange(0, num_groups) - packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32) + packed_scale = tl.sum(scale_u8.to(tl.int32) << (scale_offsets * 8), axis=0).to(tl.int32) tl.store( x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k, packed_scale, @@ -399,7 +416,7 @@ def _stage_deepseek_v4_mega_moe_inputs( hidden_size, top_k, BLOCK_K=block_k, - GROUP_K=32, + GROUP_K=16, # 
NVFP4: group_size=16 (scale_vec::4X) BLOCK_TOPK=block_topk, num_warps=4, ) @@ -428,14 +445,11 @@ class DeepseekV4MegaMoEExperts(nn.Module): """MegaMoE experts for DeepSeek V4 with NVFP4 quantization. Loads NVFP4 expert weights (E2M1 packed uint8 + float8_e4m3fn block scales - + float32 global scales) and converts them to MXFP4 format for the - DeepGEMM fp8_fp4_mega_moe kernel at finalize_weights time. + + float32 global scales) and feeds them natively to the DeepGEMM + fp8_nvfp4_mega_moe kernel (kind::mxf4nvf4.scale_vec::4X). - NVFP4 → MXFP4 conversion: - 1. Unpack E2M1 FP4 → BF16 - 2. Dequantize with UE8M0 block_scale * float32 global_scale - 3. Re-quantize BF16 → MXFP4 (E2M1 + UE8M0, group_size=32) - 4. Feed to deep_gemm transform_weights_for_mega_moe + No conversion to MXFP4. Experts stay NVFP4. The global scale (weight_scale_2) + is folded into the block scales before kernel consumption. """ _symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {}