In-Tree AMD Zen CPU Backend via zentorch [1/N] (#35970)

Signed-off-by: Lalithnarayan C <Lalithnarayan.C@amd.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Chinmay-Kulkarni-AMD <Chinmay.Kulkarni@amd.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Lalithnarayan C
2026-03-16 05:05:35 +05:30
committed by GitHub
parent 697e4ff352
commit 7acaea634c
9 changed files with 261 additions and 2 deletions

View File

@@ -9,6 +9,7 @@
#
# Build targets:
# vllm-openai (default): used for serving deployment
# vllm-openai-zen: vLLM from source + zentorch from PyPI via vllm[zen]
# vllm-test: used for CI tests
# vllm-dev: used for development
#
@@ -222,3 +223,19 @@ LABEL ai.vllm.build.cpu-arm-bf16="${VLLM_CPU_ARM_BF16:-false}"
LABEL ai.vllm.build.python-version="${PYTHON_VERSION:-3.12}"
ENTRYPOINT ["vllm", "serve"]
######################### ZEN CPU PYPI IMAGE #########################
# vLLM serving image with AMD Zen (zentorch) CPU optimizations, installed
# from PyPI via the vllm[zen] extra. amd64-only: zentorch targets AMD Zen
# x86-64 CPUs.
FROM vllm-openai AS vllm-openai-zen
ARG TARGETARCH
# Fail fast on non-amd64 builds instead of producing a broken image.
# NOTE: the message names the actual target (vllm-openai-zen); it previously
# referred to a nonexistent "vllm-openai-amd" target.
RUN if [ "$TARGETARCH" != "amd64" ]; then \
        echo "ERROR: vllm-openai-zen only supports --platform=linux/amd64"; \
        exit 1; \
    fi
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install "vllm[zen]"
ENTRYPOINT ["vllm", "serve"]

View File

@@ -966,6 +966,8 @@ setup(
ext_modules=ext_modules,
install_requires=get_requirements(),
extras_require={
# AMD Zen CPU optimizations via zentorch
"zen": ["zentorch"],
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.2.2"],

View File

@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for CPU unquantized GEMM dispatch behavior."""
import pytest
import torch
from vllm.model_executor.layers import utils
from vllm.platforms import current_platform
@pytest.fixture(scope="module")
def _mock_zentorch_linear_unary():
    """Register a mock zentorch_linear_unary op when zentorch is not installed.

    Allows the dispatch tests to run in CI without a real zentorch build.
    Skips registration when zentorch is already available.
    """
    # Real zentorch (or a previously-registered mock) already provides the
    # op: nothing to register or tear down.
    if hasattr(torch.ops.zentorch, "zentorch_linear_unary"):
        yield
        return
    # Create the "zentorch" op namespace and declare the op schema.
    lib_def = torch.library.Library("zentorch", "DEF")
    lib_def.define(
        "zentorch_linear_unary("
        "Tensor input, "
        "Tensor weight, "
        "Tensor? bias, "
        "bool is_weight_prepacked=False"
        ") -> Tensor"
    )
    # CPU kernel for the mock: a plain F.linear. The is_weight_prepacked
    # flag is accepted but ignored (the mock never prepacks weights).
    lib_impl = torch.library.Library("zentorch", "IMPL", "CPU")
    lib_impl.impl(
        "zentorch_linear_unary",
        lambda input, weight, bias, is_weight_prepacked=False: (
            torch.nn.functional.linear(input, weight, bias)
        ),
    )
    yield
    # Tear down in reverse registration order: implementation first, then
    # the schema. NOTE(review): _destroy() is a private torch.library API;
    # confirm it remains available on the torch versions used in CI.
    lib_impl._destroy()
    lib_def._destroy()
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
def test_dispatch_cpu_unquantized_gemm_uses_zentorch_on_zen(monkeypatch):
    """On a (forced) Zen CPU platform, the dispatched cpu_linear must match
    a reference F.linear computation."""
    # Force the Zen code path regardless of the host CPU.
    monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)

    linear = torch.nn.Linear(16, 8, bias=True)
    inputs = torch.randn(4, 16)
    reference = torch.nn.functional.linear(inputs, linear.weight, linear.bias)

    utils.dispatch_cpu_unquantized_gemm(linear, remove_weight=False)
    produced = linear.cpu_linear(inputs, linear.weight, linear.bias)
    torch.testing.assert_close(produced, reference)
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
def test_dispatch_cpu_unquantized_gemm_zen_remove_weight(monkeypatch):
    """remove_weight=True should leave the layer with an empty weight."""
    monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)

    linear = torch.nn.Linear(16, 8, bias=True)
    utils.dispatch_cpu_unquantized_gemm(linear, remove_weight=True)
    # The original weight is replaced by an empty placeholder parameter.
    assert linear.weight.numel() == 0

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import mock_open, patch
from vllm.platforms import _is_amd_zen_cpu
def test_is_amd_zen_cpu_detects_amd_with_avx512():
    """AuthenticAMD vendor together with avx512* flags is detected as Zen."""
    fake_cpuinfo = "vendor_id: AuthenticAMD\nflags: avx avx2 avx512f avx512bw"
    with patch("os.path.exists", return_value=True):
        with patch("builtins.open", mock_open(read_data=fake_cpuinfo)):
            assert _is_amd_zen_cpu()
def test_is_amd_zen_cpu_returns_false_for_amd_without_avx512():
    """An AMD CPU lacking AVX-512 flags must not be treated as Zen."""
    fake_cpuinfo = "vendor_id: AuthenticAMD\nflags: avx avx2"
    with patch("os.path.exists", return_value=True):
        with patch("builtins.open", mock_open(read_data=fake_cpuinfo)):
            assert not _is_amd_zen_cpu()
def test_is_amd_zen_cpu_returns_false_for_intel_with_avx512():
    """AVX-512 on a non-AMD vendor must not be treated as Zen."""
    fake_cpuinfo = "vendor_id: GenuineIntel\nflags: avx avx2 avx512f"
    with patch("os.path.exists", return_value=True):
        with patch("builtins.open", mock_open(read_data=fake_cpuinfo)):
            assert not _is_amd_zen_cpu()
def test_is_amd_zen_cpu_returns_false_when_cpuinfo_missing():
    """Absent /proc/cpuinfo (e.g. non-Linux) yields a False detection."""
    with patch("os.path.exists", return_value=False):
        assert _is_amd_zen_cpu() is False

View File

@@ -51,6 +51,7 @@ if TYPE_CHECKING:
VLLM_CPU_OMP_THREADS_BIND: str = "auto"
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
VLLM_CPU_SGL_KERNEL: bool = False
VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_XLA_CHECK_RECOMPILATION: bool = False
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
@@ -709,6 +710,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
else None,
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
# (Zen CPU backend) eagerly prepack weights into ZenDNN blocked layout
# at model load time. Eliminates per-inference layout conversion overhead.
"VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool(
int(os.getenv("VLLM_ZENTORCH_WEIGHT_PREPACK", "1"))
),
# If the env var is set, Ray Compiled Graph uses the specified
# channel type to communicate between workers belonging to
# different pipeline-parallel stages.
@@ -1768,6 +1774,7 @@ def compile_factors() -> dict[str, object]:
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
"VLLM_CPU_KVCACHE_SPACE",
"VLLM_CPU_MOE_PREPACK",
"VLLM_ZENTORCH_WEIGHT_PREPACK",
"VLLM_TEST_FORCE_LOAD_FORMAT",
"VLLM_ENABLE_CUDA_COMPATIBILITY",
"VLLM_CUDA_COMPATIBILITY_PATH",

View File

@@ -231,6 +231,30 @@ def dispatch_cpu_unquantized_gemm(
N, K = layer.weight.size()
dtype = layer.weight.dtype
# Zen CPU path: zentorch_linear_unary with optional eager weight prepacking.
if current_platform.is_zen_cpu() and hasattr(
torch.ops.zentorch, "zentorch_linear_unary"
):
zen_weight = layer.weight.detach()
is_prepacked = False
if envs.VLLM_ZENTORCH_WEIGHT_PREPACK and hasattr(
torch.ops.zentorch, "zentorch_weight_prepack_for_linear"
):
zen_weight = torch.ops.zentorch.zentorch_weight_prepack_for_linear(
zen_weight
)
is_prepacked = True
layer.cpu_linear = lambda x, weight, bias, _p=is_prepacked: (
torch.ops.zentorch.zentorch_linear_unary(
x, zen_weight, bias, is_weight_prepacked=_p
)
)
if remove_weight:
layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
return
if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
if getattr(layer, "bias", None) is not None:

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import os
import traceback
from itertools import chain
from typing import TYPE_CHECKING
@@ -150,6 +151,15 @@ def xpu_platform_plugin() -> str | None:
return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None
def _is_amd_zen_cpu() -> bool:
    """Detect AMD CPU with AVX-512 via /proc/cpuinfo."""
    cpuinfo_path = "/proc/cpuinfo"
    # Non-Linux hosts (no /proc) are never treated as Zen.
    if not os.path.exists(cpuinfo_path):
        return False
    with open(cpuinfo_path) as fh:
        info = fh.read()
    # Vendor string plus any avx512* flag (avx512f, avx512bw, ...).
    is_amd_vendor = "AuthenticAMD" in info
    has_avx512 = "avx512" in info
    return is_amd_vendor and has_avx512
def cpu_platform_plugin() -> str | None:
is_cpu = False
logger.debug("Checking if CPU platform is available.")
@@ -171,7 +181,24 @@ def cpu_platform_plugin() -> str | None:
except Exception as e:
logger.debug("CPU platform is not available because: %s", str(e))
return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None
if not is_cpu:
return None
if _is_amd_zen_cpu():
try:
import zentorch # noqa: F401
logger.debug(
"AMD Zen CPU detected with zentorch installed, using ZenCpuPlatform."
)
return "vllm.platforms.zen_cpu.ZenCpuPlatform"
except ImportError:
logger.debug(
"AMD Zen CPU detected but zentorch not installed, "
"falling back to CpuPlatform."
)
return "vllm.platforms.cpu.CpuPlatform"
builtin_platform_plugins = {
@@ -269,4 +296,11 @@ def __setattr__(name: str, value):
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"]
__all__ = [
"Platform",
"PlatformEnum",
"current_platform",
"CpuArchEnum",
"_init_trace",
"_is_amd_zen_cpu",
]

View File

@@ -167,6 +167,9 @@ class Platform:
def is_cpu(self) -> bool:
    """True when this platform is the in-tree CPU backend."""
    return PlatformEnum.CPU == self._enum
def is_zen_cpu(self) -> bool:
    # Base default: not a Zen CPU platform. ZenCpuPlatform overrides
    # this to return True.
    return False
def is_out_of_tree(self) -> bool:
    """True for out-of-tree (plugin-registered) platforms."""
    return PlatformEnum.OOT == self._enum

67
vllm/platforms/zen_cpu.py Normal file
View File

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.platforms.cpu import CpuPlatform
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.config import VllmConfig
class ZenCpuPlatform(CpuPlatform):
    """CPU platform with AMD Zen (ZenDNN/zentorch) optimizations.

    At model-load time (dispatch_cpu_unquantized_gemm in layers/utils.py):
    - linear ops are routed to zentorch_linear_unary;
    - with VLLM_ZENTORCH_WEIGHT_PREPACK=1 (the default), weights are eagerly
      prepacked via zentorch_weight_prepack_for_linear.
    """

    device_name: str = "cpu"
    device_type: str = "cpu"

    def is_zen_cpu(self) -> bool:
        # is_cpu() also stays True here, inherited from CpuPlatform.
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        super().check_and_update_config(vllm_config)
        cls._apply_pytorch_backports()

    @classmethod
    def _apply_pytorch_backports(cls):
        """Backport PyTorch mainline fixes missing in 2.10.

        PyTorch 2.10 has a bug in FxGraphCachePickler.dumps that doesn't
        catch ValueError, causing torch.compile cache misses. Remove this
        once we drop PyTorch 2.10 support. PT mainline already has this fix.
        """
        # Only the 2.10.x series needs the patch.
        on_torch_2_10 = is_torch_equal_or_newer("2.10.0") and not is_torch_equal_or_newer("2.11.0")
        if on_torch_2_10:
            cls._patch_fxgraphcache_pickle()

    @classmethod
    def _patch_fxgraphcache_pickle(cls):
        """Backport mainline ValueError fix to FxGraphCachePickler.dumps()."""
        from torch._inductor.codecache import BypassFxGraphCache, FxGraphCachePickler

        unpatched = FxGraphCachePickler.dumps
        # Idempotent: a previously-installed wrapper carries the marker.
        if getattr(unpatched, "_zen_patched", False):
            return

        def dumps_with_bypass(self, obj):
            # Translate ValueError into the cache-bypass signal, as on
            # PyTorch mainline, preserving the original cause.
            try:
                return unpatched(self, obj)
            except ValueError as e:
                raise BypassFxGraphCache("Failed to pickle cache key") from e

        dumps_with_bypass._zen_patched = True  # type: ignore[attr-defined]
        # NOTE(review): assumes FxGraphCachePickler.dumps is a plain
        # function (not a classmethod) on torch 2.10 — confirm.
        FxGraphCachePickler.dumps = dumps_with_bypass
        logger.info("[zen_cpu] Patched FxGraphCachePickler.dumps (ValueError fix)")