[Quantization][Refactor] use platform dict to choose kernel (#33130)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> Signed-off-by: zofia <110436990+zufangzhu@users.noreply.github.com>
This commit is contained in:
@@ -33,20 +33,30 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.xpu import ( # noqa: E501
|
||||
XPUwNa16LinearKernel,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms import PlatformEnum, current_platform
|
||||
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
|
||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
||||
PlatformEnum.CUDA: [
|
||||
CutlassW4A8LinearKernel,
|
||||
MacheteLinearKernel,
|
||||
AllSparkLinearKernel,
|
||||
MarlinLinearKernel,
|
||||
Dynamic4bitLinearKernel,
|
||||
ConchLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
],
|
||||
PlatformEnum.ROCM: [
|
||||
ConchLinearKernel,
|
||||
ExllamaLinearKernel,
|
||||
],
|
||||
PlatformEnum.XPU: [
|
||||
XPUwNa16LinearKernel,
|
||||
],
|
||||
PlatformEnum.CPU: [
|
||||
Dynamic4bitLinearKernel,
|
||||
CPUWNA16LinearKernel,
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def choose_mp_linear_kernel(
|
||||
@@ -78,7 +88,7 @@ def choose_mp_linear_kernel(
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
|
||||
failure_reasons = []
|
||||
for kernel in _POSSIBLE_KERNELS:
|
||||
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
|
||||
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
|
||||
failure_reasons.append(
|
||||
f" {kernel.__name__} disabled by environment variable"
|
||||
|
||||
Reference in New Issue
Block a user