[MoE Refactor][17/N] Apply Refactor to Bf16 (#31827)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Yongye Zhu
2026-01-15 12:53:40 -08:00
committed by GitHub
parent 8c11001ba2
commit 31c29257c8
12 changed files with 257 additions and 87 deletions

View File

@@ -0,0 +1,5 @@
model_name: "Qwen/Qwen3-30B-A3B"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"

View File

@@ -8,3 +8,4 @@ Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm.yaml
Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ht.yaml
Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml
Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml
Qwen3-30B-A3B-BF16-triton.yaml

View File

@@ -0,0 +1,7 @@
model_name: "meta-llama/Llama-4-Scout-17B-16E-Instruct"
accuracy_threshold: 0.92
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel"
env:
  VLLM_USE_FLASHINFER_MOE_FP16: "1"

View File

@@ -0,0 +1,6 @@
model_name: "meta-llama/Llama-4-Scout-17B-16E-Instruct"
accuracy_threshold: 0.92
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"

View File

@@ -0,0 +1,7 @@
model_name: "mistralai/Mixtral-8x7B-v0.1"
accuracy_threshold: 0.58
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel"
env:
  VLLM_USE_FLASHINFER_MOE_FP16: "1"

View File

@@ -0,0 +1,5 @@
model_name: "mistralai/Mixtral-8x7B-v0.1"
accuracy_threshold: 0.58
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"

View File

@@ -0,0 +1,7 @@
model_name: "Qwen/Qwen3-30B-A3B"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2 --enable-expert-parallel"
env:
  VLLM_USE_FLASHINFER_MOE_FP16: "1"

View File

@@ -0,0 +1,5 @@
model_name: "Qwen/Qwen3-30B-A3B"
accuracy_threshold: 0.88
num_questions: 1319
num_fewshot: 5
server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"

View File

@@ -11,3 +11,7 @@ Qwen3-30B-A3B-NvFp4-ModelOpt-marlin.yaml
Qwen3-30B-A3B-NvFp4-ModelOpt-fi-trtllm.yaml
Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass.yaml
Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass-dp-ep.yaml
Llama-4-Scout-BF16-fi-cutlass.yaml
Llama-4-Scout-BF16-triton.yaml
Mixtral-8x7B-BF16-fi-cutlass.yaml
Mixtral-8x7B-BF16-triton.yaml

View File

@@ -11,3 +11,5 @@ Qwen3-30B-A3B-Fp8-CT-Channel-vllm-cutlass.yaml
Llama-4-Scout-Fp8-ModelOpt-fi-cutlass.yaml
Llama-4-Scout-Fp8-ModelOpt-marlin.yaml
Llama-4-Scout-Fp8-ModelOpt-triton.yaml
Qwen3-30B-A3B-BF16-fi-cutlass.yaml
Qwen3-30B-A3B-BF16-triton.yaml

View File

@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch
from torch.nn import Module
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
swap_w13_to_w31,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
logger = init_logger(__name__)
class UnquantizedMoeBackend(Enum):
    """Kernel backends available for unquantized (bf16/fp16) fused MoE."""

    FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
    AITER = "ROCm AITER"
    TRITON = "TRITON"
    CPU = "CPU"
    XPU = "XPU"


# NOTE(zyongye): "Unsupported" backends are those that do not conform to the
# modular-kernel format. For these backends the kernel is invoked directly
# instead of going through a FusedMoEModularKernel.
UNSUPPORTED_BACKEND = frozenset(
    (
        UnquantizedMoeBackend.CPU,
        UnquantizedMoeBackend.XPU,
    )
)
def select_unquantized_moe_backend(
    use_ep: bool,
    use_dp: bool,
) -> UnquantizedMoeBackend:
    """
    Select the primary unquantized (bf16/fp16) MoE backend for the current
    platform.

    Args:
        use_ep: whether expert parallelism is enabled.
        use_dp: whether data parallelism is enabled (dp_size > 1).

    Returns:
        The chosen :class:`UnquantizedMoeBackend`.

    Raises:
        RuntimeError: if the current platform matches none of the known
            backends (previously this surfaced as an UnboundLocalError).

    Note: Shape-specific fallbacks may still occur at runtime.
    """

    def _make_log_backend(backend: UnquantizedMoeBackend) -> str:
        return f"Using {backend.value} backend for Unquantized MoE"

    backend: UnquantizedMoeBackend | None = None

    rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
    # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
    # (compute capability >= 9), and currently only for EP without DP.
    flashinfer_cutlass_moe_enabled = (
        has_flashinfer_cutlass_fused_moe()
        and envs.VLLM_USE_FLASHINFER_MOE_FP16
        and use_ep
        and (not use_dp)
        and current_platform.get_device_capability()[0] >= 9
    )

    if current_platform.is_rocm():
        backend = (
            UnquantizedMoeBackend.AITER
            if rocm_aiter_moe_enabled
            else UnquantizedMoeBackend.TRITON
        )
    elif current_platform.is_cuda():
        if flashinfer_cutlass_moe_enabled:
            backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS
        else:
            # Hint the operator when FlashInfer could be used but is not.
            if use_ep and (not use_dp):
                logger.info_once(
                    "FlashInfer CUTLASS MoE is available for EP"
                    " but not enabled, consider setting"
                    " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
                    scope="local",
                )
            elif use_dp:
                logger.info_once(
                    "FlashInfer CUTLASS MoE is currently not available for DP.",
                    scope="local",
                )
            backend = UnquantizedMoeBackend.TRITON
    elif current_platform.is_xpu():
        backend = UnquantizedMoeBackend.XPU
    elif current_platform.is_cpu():
        backend = UnquantizedMoeBackend.CPU

    if backend is None:
        raise RuntimeError(
            "No unquantized MoE backend available for the current platform."
        )

    logger.info_once(_make_log_backend(backend), scope="local")
    return backend
def convert_to_unquantized_kernel_format(
    unquantized_backend: UnquantizedMoeBackend,
    layer: Module,
    w13_weight: torch.Tensor | None = None,
    w2_weight: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rearrange MoE weights into the layout expected by the backend.

    Returns the (possibly transformed) ``(w13_weight, w2_weight)`` pair.

    NOTE(review): the AITER and FlashInfer branches read
    ``layer.w13_weight``/``layer.w2_weight`` directly, so the ``w13_weight``
    and ``w2_weight`` arguments act only as pass-through values for the
    untransformed case — confirm callers always pass the layer's own weights.
    """
    if unquantized_backend is UnquantizedMoeBackend.AITER:
        # AITER kernels consume pre-shuffled weight layouts.
        return rocm_aiter_ops.shuffle_weights(
            layer.w13_weight.data, layer.w2_weight.data
        )
    if unquantized_backend is UnquantizedMoeBackend.FLASHINFER_CUTLASS:
        # Swap halves to arrange as [w3; w1] (kernel expectation)
        return swap_w13_to_w31(layer.w13_weight.data), w2_weight
    # All other backends use the weights as-is.
    return w13_weight, w2_weight
def make_unquantized_moe_kernel(
    layer: torch.nn.Module,
    backend: UnquantizedMoeBackend,
    quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
) -> tuple[mk.FusedMoEModularKernel | None, bool]:
    """Construct the modular MoE kernel for ``backend``.

    Returns ``(kernel, use_inplace)``. ``kernel`` is ``None`` for backends
    that do not conform to the modular-kernel format (CPU/XPU); callers are
    expected to dispatch those directly.
    """
    if backend in UNSUPPORTED_BACKEND:
        # CPU/XPU are handled outside the modular-kernel machinery.
        return None, True

    use_inplace = True
    if backend is UnquantizedMoeBackend.FLASHINFER_CUTLASS:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
            FlashInferExperts,
        )

        parallel = moe_config.moe_parallel_config
        experts = FlashInferExperts(
            out_dtype=layer.params_dtype,
            quant_config=quant_config,
            tp_rank=parallel.tp_rank,
            tp_size=parallel.tp_size,
            ep_rank=parallel.ep_rank,
            ep_size=parallel.ep_size,
        )
        # FlashInfer produces a fresh output buffer rather than writing
        # in place.
        use_inplace = False
    elif backend is UnquantizedMoeBackend.AITER:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,
        )

        experts = AiterExperts(quant_config)
    elif backend is UnquantizedMoeBackend.TRITON:
        from vllm.model_executor.layers.fused_moe import TritonExperts

        experts = TritonExperts(quant_config)

    kernel = mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        experts,
    )
    return kernel, use_inplace

View File

@@ -4,10 +4,10 @@
import torch
import torch.nn.functional as F
from torch.nn import Module
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
@@ -16,9 +16,6 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
biased_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
@@ -28,19 +25,15 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
swap_w13_to_w31,
from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
UnquantizedMoeBackend,
convert_to_unquantized_kernel_format,
make_unquantized_moe_kernel,
select_unquantized_moe_backend,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
@@ -57,41 +50,13 @@ logger = init_logger(__name__)
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
# --8<-- [end:unquantized_fused_moe]
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
self.flashinfer_cutlass_moe_enabled = (
has_flashinfer_cutlass_fused_moe()
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and self.moe.moe_parallel_config.use_ep
and self.moe.moe_parallel_config.dp_size == 1
and current_platform.get_device_capability()[0] >= 9
self.unquantized_backend = select_unquantized_moe_backend(
use_ep=self.moe.moe_parallel_config.use_ep,
use_dp=self.moe.moe_parallel_config.dp_size > 1,
)
if self.flashinfer_cutlass_moe_enabled:
logger.info_once(
"Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
)
else:
if (
self.moe.moe_parallel_config.use_ep
and self.moe.moe_parallel_config.dp_size == 1
):
logger.info_once(
"FlashInfer CUTLASS MoE is available for EP"
" but not enabled, consider setting"
" VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
scope="local",
)
elif self.moe.moe_parallel_config.dp_size > 1:
logger.info_once(
"FlashInfer CUTLASS MoE is currently not available for DP.",
scope="local",
)
self.kernel: mk.FusedMoEModularKernel | None = None
@property
def supports_eplb(self) -> bool:
@@ -105,7 +70,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
if self.rocm_aiter_moe_enabled:
if self.unquantized_backend == UnquantizedMoeBackend.AITER:
return None
else:
return super().maybe_make_prepare_finalize(routing_tables)
@@ -197,6 +162,33 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
return weight
def _setup_kernel(
self,
layer: Module,
w13: torch.Tensor,
w2: torch.Tensor,
) -> None:
# Shuffle weights to runtime format.
w13, w2 = convert_to_unquantized_kernel_format(
self.unquantized_backend,
layer=layer,
w13_weight=w13,
w2_weight=w2,
)
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
# Setup Modular Kernel for TP Case
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.kernel, self.use_inplace = make_unquantized_moe_kernel(
layer=layer,
backend=self.unquantized_backend,
quant_config=self.moe_quant_config,
moe_config=self.moe,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
@@ -204,7 +196,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
if current_platform.is_xpu():
if self.unquantized_backend == UnquantizedMoeBackend.XPU:
import intel_extension_for_pytorch as ipex
ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
@@ -214,7 +206,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_prepack=True,
experts_start_id=ep_rank_start,
)
elif current_platform.is_cpu():
elif self.unquantized_backend == UnquantizedMoeBackend.CPU:
from vllm.model_executor.layers.fused_moe import cpu_fused_moe
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
@@ -246,45 +238,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
elif current_platform.is_cuda_alike():
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
replace_parameter(layer, "w13_weight", shuffled_w13)
replace_parameter(layer, "w2_weight", shuffled_w2)
self.use_inplace = True
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
AiterExperts(self.moe_quant_config),
shared_experts=None,
)
elif self.flashinfer_cutlass_moe_enabled:
self.use_inplace = False
# Swap halves to arrange as [w3; w1] (kernel expectation)
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
replace_parameter(layer, "w13_weight", w13_weight)
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
out_dtype=layer.params_dtype,
quant_config=self.moe_quant_config,
tp_rank=self.moe.moe_parallel_config.tp_rank,
tp_size=self.moe.moe_parallel_config.tp_size,
ep_rank=self.moe.moe_parallel_config.ep_rank,
ep_size=self.moe.moe_parallel_config.ep_size,
),
)
else:
self.use_inplace = True
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(self.moe_quant_config),
shared_experts=None,
)
self._setup_kernel(
layer=layer,
w13=layer.w13_weight,
w2=layer.w2_weight,
)
def apply(
self,
@@ -316,6 +274,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
router_logits=router_logits,