[UX] Add --moe-backend arg for explicit kernel selection (#33807)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Michael Goin
2026-02-25 20:44:44 -05:00
committed by GitHub
parent 1976356ee6
commit de527e1cec
37 changed files with 260 additions and 140 deletions

View File

@@ -1066,7 +1066,6 @@ class FusedMoEParallelConfig:
- Comment: There are 2 engine instances and the experts are split
between the 4 devices.
"""
use_ep = (
dp_size_ * pcp_size_ * tp_size_ > 1
and vllm_parallel_config.enable_expert_parallel
@@ -1155,6 +1154,7 @@ class FusedMoEConfig:
# Defaults to in_dtype if not specified.
router_logits_dtype: torch.dtype | None = None
moe_backend: str = "auto"
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
has_bias: bool = False
is_act_and_mul: bool = True

View File

@@ -198,7 +198,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
x = x[0].permute(2, 0, 1)
num_experts, max_tokens, hidden_dim_by_2 = x.shape
hidden_dim = hidden_dim_by_2 * 2
assert envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm"
logger.info_once(
"Quantization is fused with DeepEP nvfp4 dispatch for "
"FlashInfer CUTEDSL as VLLM_DEEPEPLL_NVFP4_DISPATCH==1"

View File

@@ -550,6 +550,7 @@ class FusedMoE(CustomOp):
num_logical_experts=self.logical_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype,
moe_backend=vllm_config.kernel_config.moe_backend,
router_logits_dtype=router_logits_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,

View File

@@ -7,6 +7,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
@@ -180,6 +181,25 @@ def backend_to_kernel_cls(
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
def map_fp8_backend(runner_backend: MoEBackend) -> Fp8MoeBackend:
"""Map user's MoEBackend to Fp8MoeBackend."""
mapping = {
"triton": Fp8MoeBackend.TRITON,
"deep_gemm": Fp8MoeBackend.DEEPGEMM,
"cutlass": Fp8MoeBackend.VLLM_CUTLASS,
"flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM,
"flashinfer_cutlass": Fp8MoeBackend.FLASHINFER_CUTLASS,
"marlin": Fp8MoeBackend.MARLIN,
"aiter": Fp8MoeBackend.AITER,
}
if backend := mapping.get(runner_backend):
return backend
raise ValueError(
f"moe_backend='{runner_backend}' is not supported for FP8 MoE. "
f"Expected one of {list(mapping.keys())}."
)
def select_fp8_moe_backend(
config: FusedMoEConfig,
weight_key: QuantKey | None,
@@ -242,6 +262,45 @@ def select_fp8_moe_backend(
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user.
runner_backend = config.moe_backend
if runner_backend != "auto":
requested_backend = map_fp8_backend(runner_backend)
# For batched activation format, use batched variants if available.
if activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
if requested_backend == Fp8MoeBackend.DEEPGEMM:
requested_backend = Fp8MoeBackend.BATCHED_DEEPGEMM
elif requested_backend == Fp8MoeBackend.TRITON:
requested_backend = Fp8MoeBackend.BATCHED_TRITON
elif requested_backend == Fp8MoeBackend.VLLM_CUTLASS:
requested_backend = Fp8MoeBackend.BATCHED_VLLM_CUTLASS
if (
requested_backend
in [
Fp8MoeBackend.VLLM_CUTLASS,
Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
]
and not allow_vllm_cutlass
):
raise ValueError(
"vLLM CUTLASS FP8 MoE backend is disabled for this configuration."
)
# Handle FLASHINFER_TRTLLM specially (no kernel class).
if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
supported, reason = is_supported_config_trtllm_fp8(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(requested_backend))
return requested_backend, None
raise ValueError(_make_log_unsupported(requested_backend, reason))
return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format
)
# Handle explicit FlashInfer FP8 configuration.
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP8"):
if not envs.VLLM_USE_FLASHINFER_MOE_FP8:

View File

@@ -6,6 +6,7 @@ import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
@@ -103,6 +104,23 @@ def backend_to_kernel_cls(
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend:
"""Map user's MoEBackend to NvFp4MoeBackend."""
mapping = {
"cutlass": NvFp4MoeBackend.VLLM_CUTLASS,
"flashinfer_trtllm": NvFp4MoeBackend.FLASHINFER_TRTLLM,
"flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS,
"flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL,
"marlin": NvFp4MoeBackend.MARLIN,
}
if backend := mapping.get(runner_backend):
return backend
raise ValueError(
f"moe_backend='{runner_backend}' is not supported for NvFP4 MoE. "
f"Expected one of {list(mapping.keys())}."
)
def select_nvfp4_moe_backend(
config: FusedMoEConfig,
weight_key: QuantKey | None,
@@ -170,6 +188,23 @@ def select_nvfp4_moe_backend(
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user.
runner_backend = config.moe_backend
if runner_backend != "auto":
requested_backend = map_nvfp4_backend(runner_backend)
if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
supported, reason = is_supported_config_trtllm(
config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(requested_backend))
return requested_backend, None
raise ValueError(_make_log_unsupported(requested_backend, reason))
return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format
)
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
# If the user rejects FlashInfer remove those backends.

View File

@@ -9,6 +9,7 @@ from torch.nn import Module
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
@@ -51,6 +52,22 @@ UNSUPPORTED_BACKEND = [
]
def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend:
"""Map user's MoEBackend to UnquantizedMoeBackend."""
mapping = {
"triton": UnquantizedMoeBackend.TRITON,
"flashinfer_trtllm": UnquantizedMoeBackend.FLASHINFER_TRTLLM,
"flashinfer_cutlass": UnquantizedMoeBackend.FLASHINFER_CUTLASS,
"aiter": UnquantizedMoeBackend.AITER,
}
if backend := mapping.get(runner_backend):
return backend
raise ValueError(
f"moe_backend='{runner_backend}' is not supported for unquantized MoE. "
f"Expected one of {list(mapping.keys())}."
)
def select_unquantized_moe_backend(
moe_config: FusedMoEConfig,
use_ep: bool,
@@ -64,8 +81,6 @@ def select_unquantized_moe_backend(
def _make_log_backend(backend: UnquantizedMoeBackend):
return f"Using {backend.value} backend for Unquantized MoE"
rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if moe_config.moe_parallel_config.use_batched_activation_format
@@ -77,20 +92,49 @@ def select_unquantized_moe_backend(
moe_config=moe_config,
activation_format=activation_format,
)
flashinfer_trtllm_moe_enabled = (
has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and trtllm_supported
and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
)
flashinfer_trtllm_available = has_flashinfer() and trtllm_supported
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
flashinfer_cutlass_moe_enabled = (
flashinfer_cutlass_available = (
has_flashinfer_cutlass_fused_moe()
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and use_ep
and (not use_dp)
and current_platform.has_device_capability(90)
)
flashinfer_trtllm_moe_enabled = (
flashinfer_trtllm_available
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
)
flashinfer_cutlass_moe_enabled = (
flashinfer_cutlass_available and envs.VLLM_USE_FLASHINFER_MOE_FP16
)
rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# Handle explicit moe_backend from user.
runner_backend = moe_config.moe_backend
if runner_backend != "auto":
requested_backend = map_unquantized_backend(runner_backend)
if requested_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
if not flashinfer_trtllm_available:
raise ValueError(
"FlashInfer TRTLLM MoE backend is not available for this "
"configuration."
)
elif requested_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
if not flashinfer_cutlass_available:
raise ValueError(
"FlashInfer CUTLASS MoE backend is not available for this "
"configuration."
)
elif requested_backend == UnquantizedMoeBackend.AITER and not (
current_platform.is_rocm() and rocm_aiter_moe_enabled
):
raise ValueError(
"ROCm AITer MoE backend is not available for this configuration."
)
logger.info_once(_make_log_backend(requested_backend), scope="local")
return requested_backend
if current_platform.is_rocm():
if rocm_aiter_moe_enabled:
backend = UnquantizedMoeBackend.AITER