diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 9030fbcf3..93706e0b1 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -33,20 +33,30 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer from vllm.model_executor.layers.quantization.kernels.mixed_precision.xpu import ( # noqa: E501 XPUwNa16LinearKernel, ) -from vllm.platforms import current_platform +from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) -_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [ - CutlassW4A8LinearKernel, - MacheteLinearKernel, - AllSparkLinearKernel, - MarlinLinearKernel, - Dynamic4bitLinearKernel, - ConchLinearKernel, - ExllamaLinearKernel, - XPUwNa16LinearKernel, - CPUWNA16LinearKernel, -] +_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = { + PlatformEnum.CUDA: [ + CutlassW4A8LinearKernel, + MacheteLinearKernel, + AllSparkLinearKernel, + MarlinLinearKernel, + ConchLinearKernel, + ExllamaLinearKernel, + ], + PlatformEnum.ROCM: [ + ConchLinearKernel, + ExllamaLinearKernel, + ], + PlatformEnum.XPU: [ + XPUwNa16LinearKernel, + ], + PlatformEnum.CPU: [ + Dynamic4bitLinearKernel, + CPUWNA16LinearKernel, + ], +} def choose_mp_linear_kernel( @@ -78,7 +88,7 @@ def choose_mp_linear_kernel( compute_capability = _cc[0] * 10 + _cc[1] failure_reasons = [] - for kernel in _POSSIBLE_KERNELS: + for kernel in _POSSIBLE_KERNELS[current_platform._enum]: if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: failure_reasons.append( f" {kernel.__name__} disabled by environment variable"