From b1cf4232ee7d23e3592fafe8007de5f4a6f4815e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 06:22:11 +0000 Subject: [PATCH] feat: wire DeepGEMM NVFP4 mega_moe kernel into vLLM patch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - DeepseekV4MegaMoEExperts now uses native NVFP4 path - finalize_weights: transform_nvfp4_weights_for_mega_moe() instead of NVFP4→BF16→MXFP4 conversion - forward: fp8_nvfp4_mega_moe() with recipe=(1,1,16) - Experts stay in NVFP4. No MXFP4 conversion. Period. --- patches/deepseek_v4.py | 63 ++++++++++++------------------------------ 1 file changed, 17 insertions(+), 46 deletions(-) diff --git a/patches/deepseek_v4.py b/patches/deepseek_v4.py index 97ed479..8141058 100644 --- a/patches/deepseek_v4.py +++ b/patches/deepseek_v4.py @@ -623,54 +623,24 @@ class DeepseekV4MegaMoEExperts(nn.Module): return self._check_runtime_supported() - import vllm.third_party.deep_gemm as deep_gemm + from deep_gemm.mega import transform_nvfp4_weights_for_mega_moe - # === NVFP4 → BF16 → MXFP4 conversion === - # The DeepGEMM mega_moe kernel expects MXFP4 format: - # - E2M1 packed uint8 (same as NVFP4) - # - UE8M0 uint8 block scales, group_size=32 - # NVFP4 has: - # - E2M1 packed uint8 (same) - # - E8M0 float8_e4m3fn block scales, group_size=16 - # - float32 global_scale and input_scale - # We dequant NVFP4→BF16 then requant BF16→MXFP4. - - w13_bf16 = self._nvfp4_to_bf16( - self.w13_weight.data, self.w13_weight_scale.data, - self.w13_weight_scale_2.data, self.w13_input_scale.data, - ) - w2_bf16 = self._nvfp4_to_bf16( - self.w2_weight.data, self.w2_weight_scale.data, - self.w2_weight_scale_2.data, self.w2_input_scale.data, - ) - - # Re-quantize BF16 → MXFP4 (E2M1 + UE8M0, group_size=32) - MXFP4_GROUP_SIZE = 32 - w13_mxfp4_weight, w13_mxfp4_scale = self._bf16_to_mxfp4( - w13_bf16, MXFP4_GROUP_SIZE) - w2_mxfp4_weight, w2_mxfp4_scale = self._bf16_to_mxfp4( - w2_bf16, MXFP4_GROUP_SIZE) - - # Transform into DeepGEMM mega_moe layout - w13_scale = deep_gemm.transform_sf_into_required_layout( - w13_mxfp4_scale.contiguous(), - 2 * self.intermediate_size, - self.hidden_size, - (1, 32), - self.num_local_experts, - ) - w2_scale = deep_gemm.transform_sf_into_required_layout( - w2_mxfp4_scale.contiguous(), - self.hidden_size, - self.intermediate_size, - (1, 32), - self.num_local_experts, - ) + # === Native NVFP4 path === + # The DeepGEMM nvfp4 mega_moe kernel consumes NVFP4 directly: + # - E2M1 packed uint8 (same as checkpoint) + # - UE4M3 block scales (float8_e4m3fn), group_size=16 + # - float32 global scale folded into block scales + # No conversion to MXFP4. Experts stay NVFP4. + # Fold global scales into block scales and transform for the kernel self._transformed_l1_weights, self._transformed_l2_weights = ( - deep_gemm.transform_weights_for_mega_moe( - (w13_mxfp4_weight.view(torch.int8).contiguous(), w13_scale), - (w2_mxfp4_weight.view(torch.int8).contiguous(), w2_scale), + transform_nvfp4_weights_for_mega_moe( + (self.w13_weight.data.contiguous(), + self.w13_weight_scale.data.contiguous()), + (self.w2_weight.data.contiguous(), + self.w2_weight_scale.data.contiguous()), + l1_weight_scale_2=self.w13_weight_scale_2.data.contiguous(), + l2_weight_scale_2=self.w2_weight_scale_2.data.contiguous(), ) ) @@ -864,7 +834,8 @@ class DeepseekV4MegaMoEExperts(nn.Module): assert self._transformed_l1_weights is not None assert self._transformed_l2_weights is not None - deep_gemm.fp8_fp4_mega_moe( + from deep_gemm.mega import fp8_nvfp4_mega_moe + fp8_nvfp4_mega_moe( y, self._transformed_l1_weights, self._transformed_l2_weights,