# Patch context (vllm/model_executor/layers/fused_moe/oracle/fp8.py):
#   * `kFp8Dynamic128Sym` and `kFp8Static128BlockSym` are added to the existing
#     `vllm.model_executor.layers.quantization.utils.quant_utils` import,
#     next to `QuantKey`.
#   * `select_fp8_moe_backend` drops its inline `AVAILABLE_BACKENDS` list and
#     instead assigns
#     `AVAILABLE_BACKENDS = _get_priority_backends(config, weight_key, activation_key)`.
#   * NOTE(review): the removed inline list placed each BATCHED_* backend right
#     after its non-batched counterpart; the list below places all BATCHED_*
#     entries after MARLIN. Confirm the reordering is intended.


def _get_priority_backends(
    moe_config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> list[Fp8MoeBackend]:
    """
    Return the FP8 MoE backends to try, highest priority first.

    A fresh list is built on every call, so callers may mutate the result.
    The default order is adjusted only for block FP8 (``kFp8Dynamic128Sym``
    activations with ``kFp8Static128BlockSym`` weights) on a CUDA device of
    capability 90; in that case a single backend is promoted to the front
    and the relative order of all other backends is left untouched.

    Args:
        moe_config: MoE configuration; only
            ``moe_parallel_config.ep_size`` is read.
        weight_key: Quantization key of the weights, or ``None``.
        activation_key: Quantization key of the activations, or ``None``.

    Returns:
        All ``Fp8MoeBackend`` candidates in priority order.
    """
    # Default priority order.
    ordered = [
        Fp8MoeBackend.AITER,
        Fp8MoeBackend.FLASHINFER_TRTLLM,
        Fp8MoeBackend.FLASHINFER_CUTLASS,
        Fp8MoeBackend.DEEPGEMM,
        Fp8MoeBackend.VLLM_CUTLASS,
        Fp8MoeBackend.TRITON,
        Fp8MoeBackend.MARLIN,
        Fp8MoeBackend.BATCHED_DEEPGEMM,
        Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
        Fp8MoeBackend.BATCHED_TRITON,
        Fp8MoeBackend.XPU,
    ]

    # The platform checks come first so the key comparisons are only
    # evaluated on a matching device.
    hopper_block_fp8 = (
        current_platform.is_cuda()
        and current_platform.is_device_capability(90)
        and activation_key == kFp8Dynamic128Sym
        and weight_key == kFp8Static128BlockSym
    )
    if not hopper_block_fp8:
        return ordered

    # On Hopper for Block Fp8, prefer Triton for TP and FI CUTLASS for EP.
    preferred = (
        Fp8MoeBackend.FLASHINFER_CUTLASS
        if moe_config.moe_parallel_config.ep_size > 1
        else Fp8MoeBackend.TRITON
    )
    ordered.remove(preferred)
    ordered.insert(0, preferred)
    return ordered