[Quantization][Refactor] use platform dict to choose kernel (#33130)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
Signed-off-by: zofia <110436990+zufangzhu@users.noreply.github.com>
This commit is contained in:
@@ -33,20 +33,30 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKer
|
|||||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision.xpu import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.kernels.mixed_precision.xpu import ( # noqa: E501
|
||||||
XPUwNa16LinearKernel,
|
XPUwNa16LinearKernel,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import PlatformEnum, current_platform
|
||||||
|
|
||||||
# in priority/performance order (when available)
|
# in priority/performance order (when available)
|
||||||
_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
|
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
|
||||||
CutlassW4A8LinearKernel,
|
PlatformEnum.CUDA: [
|
||||||
MacheteLinearKernel,
|
CutlassW4A8LinearKernel,
|
||||||
AllSparkLinearKernel,
|
MacheteLinearKernel,
|
||||||
MarlinLinearKernel,
|
AllSparkLinearKernel,
|
||||||
Dynamic4bitLinearKernel,
|
MarlinLinearKernel,
|
||||||
ConchLinearKernel,
|
ConchLinearKernel,
|
||||||
ExllamaLinearKernel,
|
ExllamaLinearKernel,
|
||||||
XPUwNa16LinearKernel,
|
],
|
||||||
CPUWNA16LinearKernel,
|
PlatformEnum.ROCM: [
|
||||||
]
|
ConchLinearKernel,
|
||||||
|
ExllamaLinearKernel,
|
||||||
|
],
|
||||||
|
PlatformEnum.XPU: [
|
||||||
|
XPUwNa16LinearKernel,
|
||||||
|
],
|
||||||
|
PlatformEnum.CPU: [
|
||||||
|
Dynamic4bitLinearKernel,
|
||||||
|
CPUWNA16LinearKernel,
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def choose_mp_linear_kernel(
|
def choose_mp_linear_kernel(
|
||||||
@@ -78,7 +88,7 @@ def choose_mp_linear_kernel(
|
|||||||
compute_capability = _cc[0] * 10 + _cc[1]
|
compute_capability = _cc[0] * 10 + _cc[1]
|
||||||
|
|
||||||
failure_reasons = []
|
failure_reasons = []
|
||||||
for kernel in _POSSIBLE_KERNELS:
|
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
|
||||||
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
|
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
|
||||||
failure_reasons.append(
|
failure_reasons.append(
|
||||||
f" {kernel.__name__} disabled by environment variable"
|
f" {kernel.__name__} disabled by environment variable"
|
||||||
|
|||||||
Reference in New Issue
Block a user