[UX] Add --moe-backend arg for explicit kernel selection (#33807)
Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -1066,7 +1066,6 @@ class FusedMoEParallelConfig:
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 4 devices.
|
||||
"""
|
||||
|
||||
use_ep = (
|
||||
dp_size_ * pcp_size_ * tp_size_ > 1
|
||||
and vllm_parallel_config.enable_expert_parallel
|
||||
@@ -1155,6 +1154,7 @@ class FusedMoEConfig:
|
||||
# Defaults to in_dtype if not specified.
|
||||
router_logits_dtype: torch.dtype | None = None
|
||||
|
||||
moe_backend: str = "auto"
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
||||
has_bias: bool = False
|
||||
is_act_and_mul: bool = True
|
||||
|
||||
@@ -198,7 +198,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
x = x[0].permute(2, 0, 1)
|
||||
num_experts, max_tokens, hidden_dim_by_2 = x.shape
|
||||
hidden_dim = hidden_dim_by_2 * 2
|
||||
assert envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm"
|
||||
logger.info_once(
|
||||
"Quantization is fused with DeepEP nvfp4 dispatch for "
|
||||
"FlashInfer CUTEDSL as VLLM_DEEPEPLL_NVFP4_DISPATCH==1"
|
||||
|
||||
@@ -550,6 +550,7 @@ class FusedMoE(CustomOp):
|
||||
num_logical_experts=self.logical_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=moe_in_dtype,
|
||||
moe_backend=vllm_config.kernel_config.moe_backend,
|
||||
router_logits_dtype=router_logits_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
has_bias=has_bias,
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config.kernel import MoEBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
@@ -180,6 +181,25 @@ def backend_to_kernel_cls(
|
||||
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
|
||||
|
||||
|
||||
def map_fp8_backend(runner_backend: MoEBackend) -> Fp8MoeBackend:
    """Map the user's MoEBackend selection to the FP8 kernel backend enum.

    Args:
        runner_backend: The user-requested backend. It is looked up against
            the string keys below, so it must compare equal to those strings
            (presumably ``MoEBackend`` is a str-valued enum — TODO confirm).

    Returns:
        The matching ``Fp8MoeBackend`` member.

    Raises:
        ValueError: If the requested backend has no FP8 MoE equivalent.
    """
    mapping = {
        "triton": Fp8MoeBackend.TRITON,
        "deep_gemm": Fp8MoeBackend.DEEPGEMM,
        "cutlass": Fp8MoeBackend.VLLM_CUTLASS,
        "flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM,
        "flashinfer_cutlass": Fp8MoeBackend.FLASHINFER_CUTLASS,
        "marlin": Fp8MoeBackend.MARLIN,
        "aiter": Fp8MoeBackend.AITER,
    }
    # Explicit None check: a bare truthiness test would wrongly reject any
    # enum member that happens to be falsy (e.g. an int-valued member == 0).
    if (backend := mapping.get(runner_backend)) is not None:
        return backend
    raise ValueError(
        f"moe_backend='{runner_backend}' is not supported for FP8 MoE. "
        f"Expected one of {list(mapping.keys())}."
    )
|
||||
|
||||
|
||||
def select_fp8_moe_backend(
|
||||
config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
@@ -242,6 +262,45 @@ def select_fp8_moe_backend(
|
||||
return backend, k_cls
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
|
||||
# Handle explicit moe_backend from user.
|
||||
runner_backend = config.moe_backend
|
||||
if runner_backend != "auto":
|
||||
requested_backend = map_fp8_backend(runner_backend)
|
||||
# For batched activation format, use batched variants if available.
|
||||
if activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
|
||||
if requested_backend == Fp8MoeBackend.DEEPGEMM:
|
||||
requested_backend = Fp8MoeBackend.BATCHED_DEEPGEMM
|
||||
elif requested_backend == Fp8MoeBackend.TRITON:
|
||||
requested_backend = Fp8MoeBackend.BATCHED_TRITON
|
||||
elif requested_backend == Fp8MoeBackend.VLLM_CUTLASS:
|
||||
requested_backend = Fp8MoeBackend.BATCHED_VLLM_CUTLASS
|
||||
|
||||
if (
|
||||
requested_backend
|
||||
in [
|
||||
Fp8MoeBackend.VLLM_CUTLASS,
|
||||
Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
|
||||
]
|
||||
and not allow_vllm_cutlass
|
||||
):
|
||||
raise ValueError(
|
||||
"vLLM CUTLASS FP8 MoE backend is disabled for this configuration."
|
||||
)
|
||||
|
||||
# Handle FLASHINFER_TRTLLM specially (no kernel class).
|
||||
if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
supported, reason = is_supported_config_trtllm_fp8(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(requested_backend))
|
||||
return requested_backend, None
|
||||
raise ValueError(_make_log_unsupported(requested_backend, reason))
|
||||
|
||||
return _return_or_raise(
|
||||
requested_backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
|
||||
# Handle explicit FlashInfer FP8 configuration.
|
||||
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP8"):
|
||||
if not envs.VLLM_USE_FLASHINFER_MOE_FP8:
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config.kernel import MoEBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
@@ -103,6 +104,23 @@ def backend_to_kernel_cls(
|
||||
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
|
||||
|
||||
|
||||
def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend:
    """Map the user's MoEBackend selection to the NvFP4 kernel backend enum.

    Args:
        runner_backend: The user-requested backend. It is looked up against
            the string keys below, so it must compare equal to those strings
            (presumably ``MoEBackend`` is a str-valued enum — TODO confirm).

    Returns:
        The matching ``NvFp4MoeBackend`` member.

    Raises:
        ValueError: If the requested backend has no NvFP4 MoE equivalent.
    """
    mapping = {
        "cutlass": NvFp4MoeBackend.VLLM_CUTLASS,
        "flashinfer_trtllm": NvFp4MoeBackend.FLASHINFER_TRTLLM,
        "flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS,
        "flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        "marlin": NvFp4MoeBackend.MARLIN,
    }
    # Explicit None check: a bare truthiness test would wrongly reject any
    # enum member that happens to be falsy (e.g. an int-valued member == 0).
    if (backend := mapping.get(runner_backend)) is not None:
        return backend
    raise ValueError(
        f"moe_backend='{runner_backend}' is not supported for NvFP4 MoE. "
        f"Expected one of {list(mapping.keys())}."
    )
|
||||
|
||||
|
||||
def select_nvfp4_moe_backend(
|
||||
config: FusedMoEConfig,
|
||||
weight_key: QuantKey | None,
|
||||
@@ -170,6 +188,23 @@ def select_nvfp4_moe_backend(
|
||||
return backend, k_cls
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
|
||||
# Handle explicit moe_backend from user.
|
||||
runner_backend = config.moe_backend
|
||||
if runner_backend != "auto":
|
||||
requested_backend = map_nvfp4_backend(runner_backend)
|
||||
if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
|
||||
supported, reason = is_supported_config_trtllm(
|
||||
config, weight_key, activation_key, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(requested_backend))
|
||||
return requested_backend, None
|
||||
raise ValueError(_make_log_unsupported(requested_backend, reason))
|
||||
|
||||
return _return_or_raise(
|
||||
requested_backend, config, weight_key, activation_key, activation_format
|
||||
)
|
||||
|
||||
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
|
||||
if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
|
||||
# If the user rejects FlashInfer remove those backends.
|
||||
|
||||
@@ -9,6 +9,7 @@ from torch.nn import Module
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config.kernel import MoEBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
@@ -51,6 +52,22 @@ UNSUPPORTED_BACKEND = [
|
||||
]
|
||||
|
||||
|
||||
def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend:
    """Map the user's MoEBackend selection to the unquantized kernel backend enum.

    Args:
        runner_backend: The user-requested backend. It is looked up against
            the string keys below, so it must compare equal to those strings
            (presumably ``MoEBackend`` is a str-valued enum — TODO confirm).

    Returns:
        The matching ``UnquantizedMoeBackend`` member.

    Raises:
        ValueError: If the requested backend has no unquantized MoE equivalent.
    """
    mapping = {
        "triton": UnquantizedMoeBackend.TRITON,
        "flashinfer_trtllm": UnquantizedMoeBackend.FLASHINFER_TRTLLM,
        "flashinfer_cutlass": UnquantizedMoeBackend.FLASHINFER_CUTLASS,
        "aiter": UnquantizedMoeBackend.AITER,
    }
    # Explicit None check: a bare truthiness test would wrongly reject any
    # enum member that happens to be falsy (e.g. an int-valued member == 0).
    if (backend := mapping.get(runner_backend)) is not None:
        return backend
    raise ValueError(
        f"moe_backend='{runner_backend}' is not supported for unquantized MoE. "
        f"Expected one of {list(mapping.keys())}."
    )
|
||||
|
||||
|
||||
def select_unquantized_moe_backend(
|
||||
moe_config: FusedMoEConfig,
|
||||
use_ep: bool,
|
||||
@@ -64,8 +81,6 @@ def select_unquantized_moe_backend(
|
||||
def _make_log_backend(backend: UnquantizedMoeBackend):
|
||||
return f"Using {backend.value} backend for Unquantized MoE"
|
||||
|
||||
rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
activation_format = (
|
||||
mk.FusedMoEActivationFormat.BatchedExperts
|
||||
if moe_config.moe_parallel_config.use_batched_activation_format
|
||||
@@ -77,20 +92,49 @@ def select_unquantized_moe_backend(
|
||||
moe_config=moe_config,
|
||||
activation_format=activation_format,
|
||||
)
|
||||
flashinfer_trtllm_moe_enabled = (
|
||||
has_flashinfer()
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_FP16
|
||||
and trtllm_supported
|
||||
and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
|
||||
)
|
||||
flashinfer_trtllm_available = has_flashinfer() and trtllm_supported
|
||||
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
|
||||
flashinfer_cutlass_moe_enabled = (
|
||||
flashinfer_cutlass_available = (
|
||||
has_flashinfer_cutlass_fused_moe()
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_FP16
|
||||
and use_ep
|
||||
and (not use_dp)
|
||||
and current_platform.has_device_capability(90)
|
||||
)
|
||||
flashinfer_trtllm_moe_enabled = (
|
||||
flashinfer_trtllm_available
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_FP16
|
||||
and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
|
||||
)
|
||||
flashinfer_cutlass_moe_enabled = (
|
||||
flashinfer_cutlass_available and envs.VLLM_USE_FLASHINFER_MOE_FP16
|
||||
)
|
||||
rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
# Handle explicit moe_backend from user.
|
||||
runner_backend = moe_config.moe_backend
|
||||
if runner_backend != "auto":
|
||||
requested_backend = map_unquantized_backend(runner_backend)
|
||||
if requested_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
|
||||
if not flashinfer_trtllm_available:
|
||||
raise ValueError(
|
||||
"FlashInfer TRTLLM MoE backend is not available for this "
|
||||
"configuration."
|
||||
)
|
||||
elif requested_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
|
||||
if not flashinfer_cutlass_available:
|
||||
raise ValueError(
|
||||
"FlashInfer CUTLASS MoE backend is not available for this "
|
||||
"configuration."
|
||||
)
|
||||
elif requested_backend == UnquantizedMoeBackend.AITER and not (
|
||||
current_platform.is_rocm() and rocm_aiter_moe_enabled
|
||||
):
|
||||
raise ValueError(
|
||||
"ROCm AITer MoE backend is not available for this configuration."
|
||||
)
|
||||
logger.info_once(_make_log_backend(requested_backend), scope="local")
|
||||
return requested_backend
|
||||
|
||||
if current_platform.is_rocm():
|
||||
if rocm_aiter_moe_enabled:
|
||||
backend = UnquantizedMoeBackend.AITER
|
||||
|
||||
Reference in New Issue
Block a user