From a5aa4d5c0f31bba0491a2d9328785dd39dac33c0 Mon Sep 17 00:00:00 2001 From: zofia <110436990+zufangzhu@users.noreply.github.com> Date: Thu, 29 Jan 2026 18:44:58 +0800 Subject: [PATCH] [Quantization][Refactor] use platform dict to choose kernel (#33130) Signed-off-by: Zhu, Zufang Signed-off-by: zofia <110436990+zufangzhu@users.noreply.github.com> --- .../kernels/mixed_precision/__init__.py | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 9030fbcf3..93706e0b1 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -33,20 +33,30 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer from vllm.model_executor.layers.quantization.kernels.mixed_precision.xpu import ( # noqa: E501 XPUwNa16LinearKernel, ) -from vllm.platforms import current_platform +from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) -_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [ - CutlassW4A8LinearKernel, - MacheteLinearKernel, - AllSparkLinearKernel, - MarlinLinearKernel, - Dynamic4bitLinearKernel, - ConchLinearKernel, - ExllamaLinearKernel, - XPUwNa16LinearKernel, - CPUWNA16LinearKernel, -] +_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = { + PlatformEnum.CUDA: [ + CutlassW4A8LinearKernel, + MacheteLinearKernel, + AllSparkLinearKernel, + MarlinLinearKernel, + ConchLinearKernel, + ExllamaLinearKernel, + ], + PlatformEnum.ROCM: [ + ConchLinearKernel, + ExllamaLinearKernel, + ], + PlatformEnum.XPU: [ + XPUwNa16LinearKernel, + ], + PlatformEnum.CPU: [ + Dynamic4bitLinearKernel, + CPUWNA16LinearKernel, + ], +} def choose_mp_linear_kernel( @@ -78,7 +88,7 @@ def choose_mp_linear_kernel( compute_capability = _cc[0] * 10 + _cc[1] failure_reasons = [] - for kernel in _POSSIBLE_KERNELS: + for kernel in _POSSIBLE_KERNELS[current_platform._enum]: if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: failure_reasons.append( f" {kernel.__name__} disabled by environment variable"