[Model Bash]: Improve FP8 Oracle for Config Specific Kernel Selection (#34260)
Signed-off-by: Elizabeth Thomas <email2eliza@gmail.com> Signed-off-by: Robert Shaw <robertgshaw2-redhat@h100-02.nemg-001.lab.rdu2.dc.redhat.com> Signed-off-by: Robert Shaw <robertgshaw2@gmail.com> Co-authored-by: Robert Shaw <robertgshaw2-redhat@h100-02.nemg-001.lab.rdu2.dc.redhat.com> Co-authored-by: Robert Shaw <robertgshaw2@gmail.com>
This commit is contained in:
@@ -34,6 +34,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8Static128BlockSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -55,6 +57,49 @@ class Fp8MoeBackend(Enum):
|
||||
XPU = "XPU"
|
||||
|
||||
|
||||
def _get_priority_backends(
    moe_config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> list[Fp8MoeBackend]:
    """
    Get available backends in priority order based on platform and config.

    This function can be extended to become more complex as needed.
    """
    # Default priority order; config-specific rules below may promote
    # a backend to the front of this list.
    ordered_backends = [
        Fp8MoeBackend.AITER,
        Fp8MoeBackend.FLASHINFER_TRTLLM,
        Fp8MoeBackend.FLASHINFER_CUTLASS,
        Fp8MoeBackend.DEEPGEMM,
        Fp8MoeBackend.VLLM_CUTLASS,
        Fp8MoeBackend.TRITON,
        Fp8MoeBackend.MARLIN,
        Fp8MoeBackend.BATCHED_DEEPGEMM,
        Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
        Fp8MoeBackend.BATCHED_TRITON,
        Fp8MoeBackend.XPU,
    ]

    # On Hopper for Block Fp8, prefer Triton for TP and FI CUTLASS for EP.
    hopper_block_fp8 = (
        current_platform.is_cuda()
        and current_platform.is_device_capability(90)
        and activation_key == kFp8Dynamic128Sym
        and weight_key == kFp8Static128BlockSym
    )
    if hopper_block_fp8:
        preferred = (
            Fp8MoeBackend.FLASHINFER_CUTLASS
            if moe_config.moe_parallel_config.ep_size > 1
            else Fp8MoeBackend.TRITON
        )
        # Promote the preferred backend to the front, keeping the
        # relative order of the remaining backends unchanged.
        ordered_backends.remove(preferred)
        ordered_backends.insert(0, preferred)

    return ordered_backends
|
||||
|
||||
|
||||
def backend_to_kernel_cls(
|
||||
backend: Fp8MoeBackend,
|
||||
) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
|
||||
@@ -151,19 +196,7 @@ def select_fp8_moe_backend(
|
||||
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
|
||||
|
||||
# NOTE: the kernels are selected in the following order.
|
||||
AVAILABLE_BACKENDS = [
|
||||
Fp8MoeBackend.AITER,
|
||||
Fp8MoeBackend.FLASHINFER_TRTLLM,
|
||||
Fp8MoeBackend.FLASHINFER_CUTLASS,
|
||||
Fp8MoeBackend.DEEPGEMM,
|
||||
Fp8MoeBackend.BATCHED_DEEPGEMM,
|
||||
Fp8MoeBackend.VLLM_CUTLASS,
|
||||
Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
|
||||
Fp8MoeBackend.TRITON,
|
||||
Fp8MoeBackend.BATCHED_TRITON,
|
||||
Fp8MoeBackend.MARLIN,
|
||||
Fp8MoeBackend.XPU,
|
||||
]
|
||||
AVAILABLE_BACKENDS = _get_priority_backends(config, weight_key, activation_key)
|
||||
|
||||
# NOTE(rob): We need to peek into the P/F selection to determine
|
||||
# if we are using the batched or standard expert format, which
|
||||
|
||||
Reference in New Issue
Block a user