[Quantization][Refactor] use platform dict to choose kernel (#33130)

Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
Signed-off-by: zofia <110436990+zufangzhu@users.noreply.github.com>
This commit is contained in:
zofia
2026-01-29 18:44:58 +08:00
committed by GitHub
parent 615e8033e5
commit a5aa4d5c0f

View File

@@ -33,20 +33,30 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.xpu import (  # noqa: E501
     XPUwNa16LinearKernel,
 )
-from vllm.platforms import current_platform
+from vllm.platforms import PlatformEnum, current_platform
 
 # in priority/performance order (when available)
-_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
-    CutlassW4A8LinearKernel,
-    MacheteLinearKernel,
-    AllSparkLinearKernel,
-    MarlinLinearKernel,
-    Dynamic4bitLinearKernel,
-    ConchLinearKernel,
-    ExllamaLinearKernel,
-    XPUwNa16LinearKernel,
-    CPUWNA16LinearKernel,
-]
+_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
+    PlatformEnum.CUDA: [
+        CutlassW4A8LinearKernel,
+        MacheteLinearKernel,
+        AllSparkLinearKernel,
+        MarlinLinearKernel,
+        ConchLinearKernel,
+        ExllamaLinearKernel,
+    ],
+    PlatformEnum.ROCM: [
+        ConchLinearKernel,
+        ExllamaLinearKernel,
+    ],
+    PlatformEnum.XPU: [
+        XPUwNa16LinearKernel,
+    ],
+    PlatformEnum.CPU: [
+        Dynamic4bitLinearKernel,
+        CPUWNA16LinearKernel,
+    ],
+}
 
 
 def choose_mp_linear_kernel(
@@ -78,7 +88,7 @@ def choose_mp_linear_kernel(
             compute_capability = _cc[0] * 10 + _cc[1]
 
     failure_reasons = []
-    for kernel in _POSSIBLE_KERNELS:
+    for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
         if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
             failure_reasons.append(
                 f" {kernel.__name__} disabled by environment variable"