feat: wire DeepGEMM NVFP4 mega_moe kernel into vLLM patch
- DeepseekV4MegaMoEExperts now uses the native NVFP4 path.
- finalize_weights: calls transform_nvfp4_weights_for_mega_moe() instead of performing the NVFP4→BF16→MXFP4 conversion.
- forward: calls fp8_nvfp4_mega_moe() with recipe=(1,1,16).
- Experts stay in NVFP4; no MXFP4 conversion takes place anywhere in this path.
This commit is contained in:
@@ -623,54 +623,24 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
return
|
||||
|
||||
self._check_runtime_supported()
|
||||
import vllm.third_party.deep_gemm as deep_gemm
|
||||
from deep_gemm.mega import transform_nvfp4_weights_for_mega_moe
|
||||
|
||||
# === NVFP4 → BF16 → MXFP4 conversion ===
|
||||
# The DeepGEMM mega_moe kernel expects MXFP4 format:
|
||||
# - E2M1 packed uint8 (same as NVFP4)
|
||||
# - UE8M0 uint8 block scales, group_size=32
|
||||
# NVFP4 has:
|
||||
# - E2M1 packed uint8 (same)
|
||||
# - E4M3 (float8_e4m3fn) block scales, group_size=16
|
||||
# - float32 global_scale and input_scale
|
||||
# We dequant NVFP4→BF16 then requant BF16→MXFP4.
|
||||
|
||||
w13_bf16 = self._nvfp4_to_bf16(
|
||||
self.w13_weight.data, self.w13_weight_scale.data,
|
||||
self.w13_weight_scale_2.data, self.w13_input_scale.data,
|
||||
)
|
||||
w2_bf16 = self._nvfp4_to_bf16(
|
||||
self.w2_weight.data, self.w2_weight_scale.data,
|
||||
self.w2_weight_scale_2.data, self.w2_input_scale.data,
|
||||
)
|
||||
|
||||
# Re-quantize BF16 → MXFP4 (E2M1 + UE8M0, group_size=32)
|
||||
MXFP4_GROUP_SIZE = 32
|
||||
w13_mxfp4_weight, w13_mxfp4_scale = self._bf16_to_mxfp4(
|
||||
w13_bf16, MXFP4_GROUP_SIZE)
|
||||
w2_mxfp4_weight, w2_mxfp4_scale = self._bf16_to_mxfp4(
|
||||
w2_bf16, MXFP4_GROUP_SIZE)
|
||||
|
||||
# Transform into DeepGEMM mega_moe layout
|
||||
w13_scale = deep_gemm.transform_sf_into_required_layout(
|
||||
w13_mxfp4_scale.contiguous(),
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
(1, 32),
|
||||
self.num_local_experts,
|
||||
)
|
||||
w2_scale = deep_gemm.transform_sf_into_required_layout(
|
||||
w2_mxfp4_scale.contiguous(),
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
(1, 32),
|
||||
self.num_local_experts,
|
||||
)
|
||||
# === Native NVFP4 path ===
|
||||
# The DeepGEMM nvfp4 mega_moe kernel consumes NVFP4 directly:
|
||||
# - E2M1 packed uint8 (same as checkpoint)
|
||||
# - UE4M3 block scales (float8_e4m3fn), group_size=16
|
||||
# - float32 global scale folded into block scales
|
||||
# No conversion to MXFP4. Experts stay NVFP4.
|
||||
|
||||
# Fold global scales into block scales and transform for the kernel
|
||||
self._transformed_l1_weights, self._transformed_l2_weights = (
|
||||
deep_gemm.transform_weights_for_mega_moe(
|
||||
(w13_mxfp4_weight.view(torch.int8).contiguous(), w13_scale),
|
||||
(w2_mxfp4_weight.view(torch.int8).contiguous(), w2_scale),
|
||||
transform_nvfp4_weights_for_mega_moe(
|
||||
(self.w13_weight.data.contiguous(),
|
||||
self.w13_weight_scale.data.contiguous()),
|
||||
(self.w2_weight.data.contiguous(),
|
||||
self.w2_weight_scale.data.contiguous()),
|
||||
l1_weight_scale_2=self.w13_weight_scale_2.data.contiguous(),
|
||||
l2_weight_scale_2=self.w2_weight_scale_2.data.contiguous(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -864,7 +834,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
|
||||
assert self._transformed_l1_weights is not None
|
||||
assert self._transformed_l2_weights is not None
|
||||
deep_gemm.fp8_fp4_mega_moe(
|
||||
from deep_gemm.mega import fp8_nvfp4_mega_moe
|
||||
fp8_nvfp4_mega_moe(
|
||||
y,
|
||||
self._transformed_l1_weights,
|
||||
self._transformed_l2_weights,
|
||||
|
||||
Reference in New Issue
Block a user