In-Tree AMD Zen CPU Backend via zentorch [1/N] (#35970)
Signed-off-by: Lalithnarayan C <Lalithnarayan.C@amd.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Chinmay-Kulkarni-AMD <Chinmay.Kulkarni@amd.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -9,6 +9,7 @@
|
|||||||
#
|
#
|
||||||
# Build targets:
|
# Build targets:
|
||||||
# vllm-openai (default): used for serving deployment
|
# vllm-openai (default): used for serving deployment
|
||||||
|
# vllm-openai-zen: vLLM from source + zentorch from PyPI via vllm[zen]
|
||||||
# vllm-test: used for CI tests
|
# vllm-test: used for CI tests
|
||||||
# vllm-dev: used for development
|
# vllm-dev: used for development
|
||||||
#
|
#
|
||||||
@@ -222,3 +223,19 @@ LABEL ai.vllm.build.cpu-arm-bf16="${VLLM_CPU_ARM_BF16:-false}"
|
|||||||
LABEL ai.vllm.build.python-version="${PYTHON_VERSION:-3.12}"
|
LABEL ai.vllm.build.python-version="${PYTHON_VERSION:-3.12}"
|
||||||
|
|
||||||
ENTRYPOINT ["vllm", "serve"]
|
ENTRYPOINT ["vllm", "serve"]
|
||||||
|
|
||||||
|
|
||||||
|
######################### ZEN CPU PYPI IMAGE #########################
FROM vllm-openai AS vllm-openai-zen

ARG TARGETARCH

# zentorch ships x86-64 (AMD Zen) wheels only — fail fast on any other arch.
# BUGFIX: the error message previously named "vllm-openai-amd"; this build
# target is "vllm-openai-zen".
RUN if [ "$TARGETARCH" != "amd64" ]; then \
        echo "ERROR: vllm-openai-zen only supports --platform=linux/amd64"; \
        exit 1; \
    fi

# Pull zentorch in via the vllm[zen] extra; uv wheel cache is mounted so
# repeated builds do not re-download.
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install "vllm[zen]"

ENTRYPOINT ["vllm", "serve"]
|
|||||||
2
setup.py
2
setup.py
@@ -966,6 +966,8 @@ setup(
|
|||||||
ext_modules=ext_modules,
|
ext_modules=ext_modules,
|
||||||
install_requires=get_requirements(),
|
install_requires=get_requirements(),
|
||||||
extras_require={
|
extras_require={
|
||||||
|
# AMD Zen CPU optimizations via zentorch
|
||||||
|
"zen": ["zentorch"],
|
||||||
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
|
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
|
||||||
"tensorizer": ["tensorizer==2.10.1"],
|
"tensorizer": ["tensorizer==2.10.1"],
|
||||||
"fastsafetensors": ["fastsafetensors >= 0.2.2"],
|
"fastsafetensors": ["fastsafetensors >= 0.2.2"],
|
||||||
|
|||||||
68
tests/model_executor/test_cpu_unquantized_gemm_dispatch.py
Normal file
68
tests/model_executor/test_cpu_unquantized_gemm_dispatch.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Tests for CPU unquantized GEMM dispatch behavior."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers import utils
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def _mock_zentorch_linear_unary():
    """Register a mock zentorch_linear_unary op when zentorch is not installed.

    Allows the dispatch tests to run in CI without a real zentorch build.
    Skips registration when zentorch is already available.
    """
    # Real zentorch present: nothing to register, nothing to tear down.
    if hasattr(torch.ops.zentorch, "zentorch_linear_unary"):
        yield
        return

    schema_lib = torch.library.Library("zentorch", "DEF")
    schema_lib.define(
        "zentorch_linear_unary("
        "Tensor input, "
        "Tensor weight, "
        "Tensor? bias, "
        "bool is_weight_prepacked=False"
        ") -> Tensor"
    )

    # CPU kernel: plain F.linear; the prepack flag is accepted and ignored.
    def _linear_unary(input, weight, bias, is_weight_prepacked=False):
        return torch.nn.functional.linear(input, weight, bias)

    kernel_lib = torch.library.Library("zentorch", "IMPL", "CPU")
    kernel_lib.impl("zentorch_linear_unary", _linear_unary)

    yield

    # Unregister so later tests see a clean zentorch op namespace.
    kernel_lib._destroy()
    schema_lib._destroy()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
def test_dispatch_cpu_unquantized_gemm_uses_zentorch_on_zen(monkeypatch):
    """Zen dispatch must install cpu_linear and agree with F.linear."""
    monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)

    linear = torch.nn.Linear(16, 8, bias=True)
    inputs = torch.randn(4, 16)
    reference = torch.nn.functional.linear(inputs, linear.weight, linear.bias)

    utils.dispatch_cpu_unquantized_gemm(linear, remove_weight=False)
    actual = linear.cpu_linear(inputs, linear.weight, linear.bias)

    torch.testing.assert_close(actual, reference)
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
def test_dispatch_cpu_unquantized_gemm_zen_remove_weight(monkeypatch):
    """With remove_weight=True the original weight tensor must be emptied."""
    monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)

    linear = torch.nn.Linear(16, 8, bias=True)
    utils.dispatch_cpu_unquantized_gemm(linear, remove_weight=True)

    # The zentorch path captures its own (possibly prepacked) copy, so the
    # module's weight parameter should now hold zero elements.
    assert linear.weight.numel() == 0
37
tests/test_zen_cpu_platform_detection.py
Normal file
37
tests/test_zen_cpu_platform_detection.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from unittest.mock import mock_open, patch
|
||||||
|
|
||||||
|
from vllm.platforms import _is_amd_zen_cpu
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_amd_zen_cpu_detects_amd_with_avx512():
    """AMD vendor string plus an avx512* flag must be detected as Zen."""
    fake_cpuinfo = "vendor_id: AuthenticAMD\nflags: avx avx2 avx512f avx512bw"
    with (
        patch("os.path.exists", return_value=True),
        patch("builtins.open", mock_open(read_data=fake_cpuinfo)),
    ):
        assert _is_amd_zen_cpu()


def test_is_amd_zen_cpu_returns_false_for_amd_without_avx512():
    """AMD without AVX-512 support is not treated as a Zen CPU."""
    fake_cpuinfo = "vendor_id: AuthenticAMD\nflags: avx avx2"
    with (
        patch("os.path.exists", return_value=True),
        patch("builtins.open", mock_open(read_data=fake_cpuinfo)),
    ):
        assert not _is_amd_zen_cpu()


def test_is_amd_zen_cpu_returns_false_for_intel_with_avx512():
    """Intel CPUs are excluded even when they advertise AVX-512."""
    fake_cpuinfo = "vendor_id: GenuineIntel\nflags: avx avx2 avx512f"
    with (
        patch("os.path.exists", return_value=True),
        patch("builtins.open", mock_open(read_data=fake_cpuinfo)),
    ):
        assert not _is_amd_zen_cpu()


def test_is_amd_zen_cpu_returns_false_when_cpuinfo_missing():
    """Absent /proc/cpuinfo (non-Linux) means detection returns False."""
    with patch("os.path.exists", return_value=False):
        assert not _is_amd_zen_cpu()
@@ -51,6 +51,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_CPU_OMP_THREADS_BIND: str = "auto"
|
VLLM_CPU_OMP_THREADS_BIND: str = "auto"
|
||||||
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
|
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
|
||||||
VLLM_CPU_SGL_KERNEL: bool = False
|
VLLM_CPU_SGL_KERNEL: bool = False
|
||||||
|
VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True
|
||||||
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
||||||
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
||||||
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
|
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
|
||||||
@@ -709,6 +710,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
else None,
|
else None,
|
||||||
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
|
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
|
||||||
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
|
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
|
||||||
|
# (Zen CPU backend) eagerly prepack weights into ZenDNN blocked layout
|
||||||
|
# at model load time. Eliminates per-inference layout conversion overhead.
|
||||||
|
"VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool(
|
||||||
|
int(os.getenv("VLLM_ZENTORCH_WEIGHT_PREPACK", "1"))
|
||||||
|
),
|
||||||
# If the env var is set, Ray Compiled Graph uses the specified
|
# If the env var is set, Ray Compiled Graph uses the specified
|
||||||
# channel type to communicate between workers belonging to
|
# channel type to communicate between workers belonging to
|
||||||
# different pipeline-parallel stages.
|
# different pipeline-parallel stages.
|
||||||
@@ -1768,6 +1774,7 @@ def compile_factors() -> dict[str, object]:
|
|||||||
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
|
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
|
||||||
"VLLM_CPU_KVCACHE_SPACE",
|
"VLLM_CPU_KVCACHE_SPACE",
|
||||||
"VLLM_CPU_MOE_PREPACK",
|
"VLLM_CPU_MOE_PREPACK",
|
||||||
|
"VLLM_ZENTORCH_WEIGHT_PREPACK",
|
||||||
"VLLM_TEST_FORCE_LOAD_FORMAT",
|
"VLLM_TEST_FORCE_LOAD_FORMAT",
|
||||||
"VLLM_ENABLE_CUDA_COMPATIBILITY",
|
"VLLM_ENABLE_CUDA_COMPATIBILITY",
|
||||||
"VLLM_CUDA_COMPATIBILITY_PATH",
|
"VLLM_CUDA_COMPATIBILITY_PATH",
|
||||||
|
|||||||
@@ -231,6 +231,30 @@ def dispatch_cpu_unquantized_gemm(
|
|||||||
N, K = layer.weight.size()
|
N, K = layer.weight.size()
|
||||||
dtype = layer.weight.dtype
|
dtype = layer.weight.dtype
|
||||||
|
|
||||||
|
# Zen CPU path: zentorch_linear_unary with optional eager weight prepacking.
|
||||||
|
if current_platform.is_zen_cpu() and hasattr(
|
||||||
|
torch.ops.zentorch, "zentorch_linear_unary"
|
||||||
|
):
|
||||||
|
zen_weight = layer.weight.detach()
|
||||||
|
is_prepacked = False
|
||||||
|
|
||||||
|
if envs.VLLM_ZENTORCH_WEIGHT_PREPACK and hasattr(
|
||||||
|
torch.ops.zentorch, "zentorch_weight_prepack_for_linear"
|
||||||
|
):
|
||||||
|
zen_weight = torch.ops.zentorch.zentorch_weight_prepack_for_linear(
|
||||||
|
zen_weight
|
||||||
|
)
|
||||||
|
is_prepacked = True
|
||||||
|
|
||||||
|
layer.cpu_linear = lambda x, weight, bias, _p=is_prepacked: (
|
||||||
|
torch.ops.zentorch.zentorch_linear_unary(
|
||||||
|
x, zen_weight, bias, is_weight_prepacked=_p
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if remove_weight:
|
||||||
|
layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
|
||||||
|
return
|
||||||
|
|
||||||
if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
|
if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
|
||||||
packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
|
packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
|
||||||
if getattr(layer, "bias", None) is not None:
|
if getattr(layer, "bias", None) is not None:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
@@ -150,6 +151,15 @@ def xpu_platform_plugin() -> str | None:
|
|||||||
return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None
|
return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_amd_zen_cpu() -> bool:
|
||||||
|
"""Detect AMD CPU with AVX-512 via /proc/cpuinfo."""
|
||||||
|
if not os.path.exists("/proc/cpuinfo"):
|
||||||
|
return False
|
||||||
|
with open("/proc/cpuinfo") as f:
|
||||||
|
cpuinfo = f.read()
|
||||||
|
return "AuthenticAMD" in cpuinfo and "avx512" in cpuinfo
|
||||||
|
|
||||||
|
|
||||||
def cpu_platform_plugin() -> str | None:
|
def cpu_platform_plugin() -> str | None:
|
||||||
is_cpu = False
|
is_cpu = False
|
||||||
logger.debug("Checking if CPU platform is available.")
|
logger.debug("Checking if CPU platform is available.")
|
||||||
@@ -171,7 +181,24 @@ def cpu_platform_plugin() -> str | None:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("CPU platform is not available because: %s", str(e))
|
logger.debug("CPU platform is not available because: %s", str(e))
|
||||||
|
|
||||||
return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None
|
if not is_cpu:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if _is_amd_zen_cpu():
|
||||||
|
try:
|
||||||
|
import zentorch # noqa: F401
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"AMD Zen CPU detected with zentorch installed, using ZenCpuPlatform."
|
||||||
|
)
|
||||||
|
return "vllm.platforms.zen_cpu.ZenCpuPlatform"
|
||||||
|
except ImportError:
|
||||||
|
logger.debug(
|
||||||
|
"AMD Zen CPU detected but zentorch not installed, "
|
||||||
|
"falling back to CpuPlatform."
|
||||||
|
)
|
||||||
|
|
||||||
|
return "vllm.platforms.cpu.CpuPlatform"
|
||||||
|
|
||||||
|
|
||||||
builtin_platform_plugins = {
|
builtin_platform_plugins = {
|
||||||
@@ -269,4 +296,11 @@ def __setattr__(name: str, value):
|
|||||||
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
|
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"]
|
__all__ = [
|
||||||
|
"Platform",
|
||||||
|
"PlatformEnum",
|
||||||
|
"current_platform",
|
||||||
|
"CpuArchEnum",
|
||||||
|
"_init_trace",
|
||||||
|
"_is_amd_zen_cpu",
|
||||||
|
]
|
||||||
|
|||||||
@@ -167,6 +167,9 @@ class Platform:
|
|||||||
def is_cpu(self) -> bool:
|
def is_cpu(self) -> bool:
|
||||||
return self._enum == PlatformEnum.CPU
|
return self._enum == PlatformEnum.CPU
|
||||||
|
|
||||||
|
def is_zen_cpu(self) -> bool:
    """Whether this is an AMD Zen CPU platform; base platforms return False."""
    # Overridden to return True by ZenCpuPlatform (vllm/platforms/zen_cpu.py).
    return False
|
||||||
|
|
||||||
def is_out_of_tree(self) -> bool:
|
def is_out_of_tree(self) -> bool:
|
||||||
return self._enum == PlatformEnum.OOT
|
return self._enum == PlatformEnum.OOT
|
||||||
|
|
||||||
|
|||||||
67
vllm/platforms/zen_cpu.py
Normal file
67
vllm/platforms/zen_cpu.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms.cpu import CpuPlatform
|
||||||
|
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ZenCpuPlatform(CpuPlatform):
    """CPU platform with AMD Zen (ZenDNN/zentorch) optimizations.

    Model-load time (dispatch_cpu_unquantized_gemm in layers/utils.py):
    - Routes linear ops to zentorch_linear_unary.
    - When VLLM_ZENTORCH_WEIGHT_PREPACK=1 (default), eagerly prepacks
      weights via zentorch_weight_prepack_for_linear.
    """

    device_name: str = "cpu"
    device_type: str = "cpu"

    def is_zen_cpu(self) -> bool:
        # is_cpu() also returns True for this platform (inherited from CpuPlatform).
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        super().check_and_update_config(vllm_config)
        cls._apply_pytorch_backports()

    @classmethod
    def _apply_pytorch_backports(cls):
        """Backport PyTorch mainline fixes missing in 2.10.

        PyTorch 2.10 has a bug in FxGraphCachePickler.dumps that doesn't
        catch ValueError, causing torch.compile cache misses. Remove this
        once we drop PyTorch 2.10 support. PT mainline already has this fix.
        """
        # Only the 2.10.x series is affected; 2.11+ carries the fix upstream.
        affected = is_torch_equal_or_newer("2.10.0") and not is_torch_equal_or_newer(
            "2.11.0"
        )
        if affected:
            cls._patch_fxgraphcache_pickle()

    @classmethod
    def _patch_fxgraphcache_pickle(cls):
        """Backport mainline ValueError fix to FxGraphCachePickler.dumps()."""
        from torch._inductor.codecache import BypassFxGraphCache, FxGraphCachePickler

        unpatched_dumps = FxGraphCachePickler.dumps
        # Idempotence guard: repeated config checks must not re-wrap.
        if hasattr(unpatched_dumps, "_zen_patched"):
            return

        def patched_dumps(self, obj):
            try:
                return unpatched_dumps(self, obj)
            except ValueError as e:
                raise BypassFxGraphCache("Failed to pickle cache key") from e

        patched_dumps._zen_patched = True  # type: ignore[attr-defined]
        FxGraphCachePickler.dumps = patched_dumps
        logger.info("[zen_cpu] Patched FxGraphCachePickler.dumps (ValueError fix)")
||||||
Reference in New Issue
Block a user