In-Tree AMD Zen CPU Backend via zentorch [1/N] (#35970)
Signed-off-by: Lalithnarayan C <Lalithnarayan.C@amd.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Chinmay-Kulkarni-AMD <Chinmay.Kulkarni@amd.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#
|
||||
# Build targets:
|
||||
# vllm-openai (default): used for serving deployment
|
||||
# vllm-openai-zen: vLLM from source + zentorch from PyPI via vllm[zen]
|
||||
# vllm-test: used for CI tests
|
||||
# vllm-dev: used for development
|
||||
#
|
||||
@@ -222,3 +223,19 @@ LABEL ai.vllm.build.cpu-arm-bf16="${VLLM_CPU_ARM_BF16:-false}"
|
||||
LABEL ai.vllm.build.python-version="${PYTHON_VERSION:-3.12}"
|
||||
|
||||
ENTRYPOINT ["vllm", "serve"]
|
||||
|
||||
|
||||
######################### ZEN CPU PYPI IMAGE #########################
# vllm-openai-zen: the serving image plus AMD Zen CPU optimizations,
# installed from PyPI via the vllm[zen] extra (see setup.py).
FROM vllm-openai AS vllm-openai-zen

ARG TARGETARCH

# This stage is only supported on linux/amd64; fail the build fast with a
# clear message otherwise. (Fixed: the message previously referenced a
# non-existent "vllm-openai-amd" target instead of this stage's name.)
RUN if [ "$TARGETARCH" != "amd64" ]; then \
        echo "ERROR: vllm-openai-zen only supports --platform=linux/amd64"; \
        exit 1; \
    fi

# Cache uv downloads across builds to speed up repeated image builds.
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install "vllm[zen]"

ENTRYPOINT ["vllm", "serve"]
|
||||
|
||||
2
setup.py
2
setup.py
@@ -966,6 +966,8 @@ setup(
|
||||
ext_modules=ext_modules,
|
||||
install_requires=get_requirements(),
|
||||
extras_require={
|
||||
# AMD Zen CPU optimizations via zentorch
|
||||
"zen": ["zentorch"],
|
||||
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
|
||||
"tensorizer": ["tensorizer==2.10.1"],
|
||||
"fastsafetensors": ["fastsafetensors >= 0.2.2"],
|
||||
|
||||
68
tests/model_executor/test_cpu_unquantized_gemm_dispatch.py
Normal file
68
tests/model_executor/test_cpu_unquantized_gemm_dispatch.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for CPU unquantized GEMM dispatch behavior."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers import utils
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def _mock_zentorch_linear_unary():
    """Register a mock zentorch_linear_unary op when zentorch is not installed.

    Allows the dispatch tests to run in CI without a real zentorch build.
    Skips registration when zentorch is already available.

    Yields:
        None. Any registration performed here is torn down after all tests
        in the module have run (module-scoped fixture).
    """
    # Real zentorch (or a prior registration) present: nothing to do.
    if hasattr(torch.ops.zentorch, "zentorch_linear_unary"):
        yield
        return

    # Define the op schema in a fresh "zentorch" library namespace.
    # NOTE(review): this schema presumably mirrors the real zentorch op's
    # signature — verify against the zentorch release being targeted.
    lib_def = torch.library.Library("zentorch", "DEF")
    lib_def.define(
        "zentorch_linear_unary("
        "Tensor input, "
        "Tensor weight, "
        "Tensor? bias, "
        "bool is_weight_prepacked=False"
        ") -> Tensor"
    )

    # CPU implementation: fall back to stock F.linear. The
    # is_weight_prepacked flag is accepted but ignored by the mock.
    lib_impl = torch.library.Library("zentorch", "IMPL", "CPU")
    lib_impl.impl(
        "zentorch_linear_unary",
        lambda input, weight, bias, is_weight_prepacked=False: (
            torch.nn.functional.linear(input, weight, bias)
        ),
    )

    yield

    # Tear down in reverse order of registration so the namespace is left
    # clean for any later tests in the same process.
    lib_impl._destroy()
    lib_def._destroy()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
def test_dispatch_cpu_unquantized_gemm_uses_zentorch_on_zen(monkeypatch):
    """Dispatching on a Zen CPU routes through the zentorch linear op
    and produces the same result as stock F.linear."""
    # Force the platform check so the zentorch path is selected.
    monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)

    linear = torch.nn.Linear(16, 8, bias=True)
    inputs = torch.randn(4, 16)
    # Reference output from the vanilla PyTorch linear kernel.
    reference = torch.nn.functional.linear(inputs, linear.weight, linear.bias)

    utils.dispatch_cpu_unquantized_gemm(linear, remove_weight=False)
    actual = linear.cpu_linear(inputs, linear.weight, linear.bias)

    torch.testing.assert_close(actual, reference)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
def test_dispatch_cpu_unquantized_gemm_zen_remove_weight(monkeypatch):
    """With remove_weight=True the original weight tensor is emptied
    after dispatch (the zentorch path holds its own copy)."""
    # Force the platform check so the zentorch path is selected.
    monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)

    linear = torch.nn.Linear(16, 8, bias=True)
    utils.dispatch_cpu_unquantized_gemm(linear, remove_weight=True)

    assert linear.weight.numel() == 0
|
||||
37
tests/test_zen_cpu_platform_detection.py
Normal file
37
tests/test_zen_cpu_platform_detection.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
from vllm.platforms import _is_amd_zen_cpu
|
||||
|
||||
|
||||
def test_is_amd_zen_cpu_detects_amd_with_avx512():
    """An AMD part advertising AVX-512 flags is classified as Zen."""
    fake_cpuinfo = "vendor_id: AuthenticAMD\nflags: avx avx2 avx512f avx512bw"
    exists_patch = patch("os.path.exists", return_value=True)
    open_patch = patch("builtins.open", mock_open(read_data=fake_cpuinfo))
    with exists_patch, open_patch:
        assert _is_amd_zen_cpu()
|
||||
|
||||
|
||||
def test_is_amd_zen_cpu_returns_false_for_amd_without_avx512():
    """The AMD vendor string alone is not enough: AVX-512 is required."""
    fake_cpuinfo = "vendor_id: AuthenticAMD\nflags: avx avx2"
    exists_patch = patch("os.path.exists", return_value=True)
    open_patch = patch("builtins.open", mock_open(read_data=fake_cpuinfo))
    with exists_patch, open_patch:
        assert not _is_amd_zen_cpu()
|
||||
|
||||
|
||||
def test_is_amd_zen_cpu_returns_false_for_intel_with_avx512():
    """AVX-512 on a non-AMD vendor must not be classified as Zen."""
    fake_cpuinfo = "vendor_id: GenuineIntel\nflags: avx avx2 avx512f"
    exists_patch = patch("os.path.exists", return_value=True)
    open_patch = patch("builtins.open", mock_open(read_data=fake_cpuinfo))
    with exists_patch, open_patch:
        assert not _is_amd_zen_cpu()
|
||||
|
||||
|
||||
def test_is_amd_zen_cpu_returns_false_when_cpuinfo_missing():
    """Hosts without /proc/cpuinfo (non-Linux) must fail detection closed."""
    missing_patch = patch("os.path.exists", return_value=False)
    with missing_patch:
        assert not _is_amd_zen_cpu()
|
||||
@@ -51,6 +51,7 @@ if TYPE_CHECKING:
|
||||
VLLM_CPU_OMP_THREADS_BIND: str = "auto"
|
||||
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
|
||||
VLLM_CPU_SGL_KERNEL: bool = False
|
||||
VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True
|
||||
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
||||
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
||||
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
|
||||
@@ -709,6 +710,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
else None,
|
||||
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
|
||||
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
|
||||
# (Zen CPU backend) eagerly prepack weights into ZenDNN blocked layout
|
||||
# at model load time. Eliminates per-inference layout conversion overhead.
|
||||
"VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool(
|
||||
int(os.getenv("VLLM_ZENTORCH_WEIGHT_PREPACK", "1"))
|
||||
),
|
||||
# If the env var is set, Ray Compiled Graph uses the specified
|
||||
# channel type to communicate between workers belonging to
|
||||
# different pipeline-parallel stages.
|
||||
@@ -1768,6 +1774,7 @@ def compile_factors() -> dict[str, object]:
|
||||
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
|
||||
"VLLM_CPU_KVCACHE_SPACE",
|
||||
"VLLM_CPU_MOE_PREPACK",
|
||||
"VLLM_ZENTORCH_WEIGHT_PREPACK",
|
||||
"VLLM_TEST_FORCE_LOAD_FORMAT",
|
||||
"VLLM_ENABLE_CUDA_COMPATIBILITY",
|
||||
"VLLM_CUDA_COMPATIBILITY_PATH",
|
||||
|
||||
@@ -231,6 +231,30 @@ def dispatch_cpu_unquantized_gemm(
|
||||
N, K = layer.weight.size()
|
||||
dtype = layer.weight.dtype
|
||||
|
||||
# Zen CPU path: zentorch_linear_unary with optional eager weight prepacking.
|
||||
if current_platform.is_zen_cpu() and hasattr(
|
||||
torch.ops.zentorch, "zentorch_linear_unary"
|
||||
):
|
||||
zen_weight = layer.weight.detach()
|
||||
is_prepacked = False
|
||||
|
||||
if envs.VLLM_ZENTORCH_WEIGHT_PREPACK and hasattr(
|
||||
torch.ops.zentorch, "zentorch_weight_prepack_for_linear"
|
||||
):
|
||||
zen_weight = torch.ops.zentorch.zentorch_weight_prepack_for_linear(
|
||||
zen_weight
|
||||
)
|
||||
is_prepacked = True
|
||||
|
||||
layer.cpu_linear = lambda x, weight, bias, _p=is_prepacked: (
|
||||
torch.ops.zentorch.zentorch_linear_unary(
|
||||
x, zen_weight, bias, is_weight_prepacked=_p
|
||||
)
|
||||
)
|
||||
if remove_weight:
|
||||
layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
|
||||
return
|
||||
|
||||
if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
|
||||
packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
|
||||
if getattr(layer, "bias", None) is not None:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -150,6 +151,15 @@ def xpu_platform_plugin() -> str | None:
|
||||
return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None
|
||||
|
||||
|
||||
def _is_amd_zen_cpu() -> bool:
|
||||
"""Detect AMD CPU with AVX-512 via /proc/cpuinfo."""
|
||||
if not os.path.exists("/proc/cpuinfo"):
|
||||
return False
|
||||
with open("/proc/cpuinfo") as f:
|
||||
cpuinfo = f.read()
|
||||
return "AuthenticAMD" in cpuinfo and "avx512" in cpuinfo
|
||||
|
||||
|
||||
def cpu_platform_plugin() -> str | None:
|
||||
is_cpu = False
|
||||
logger.debug("Checking if CPU platform is available.")
|
||||
@@ -171,7 +181,24 @@ def cpu_platform_plugin() -> str | None:
|
||||
except Exception as e:
|
||||
logger.debug("CPU platform is not available because: %s", str(e))
|
||||
|
||||
return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None
|
||||
if not is_cpu:
|
||||
return None
|
||||
|
||||
if _is_amd_zen_cpu():
|
||||
try:
|
||||
import zentorch # noqa: F401
|
||||
|
||||
logger.debug(
|
||||
"AMD Zen CPU detected with zentorch installed, using ZenCpuPlatform."
|
||||
)
|
||||
return "vllm.platforms.zen_cpu.ZenCpuPlatform"
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
"AMD Zen CPU detected but zentorch not installed, "
|
||||
"falling back to CpuPlatform."
|
||||
)
|
||||
|
||||
return "vllm.platforms.cpu.CpuPlatform"
|
||||
|
||||
|
||||
builtin_platform_plugins = {
|
||||
@@ -269,4 +296,11 @@ def __setattr__(name: str, value):
|
||||
raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
|
||||
|
||||
|
||||
__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"]
|
||||
__all__ = [
|
||||
"Platform",
|
||||
"PlatformEnum",
|
||||
"current_platform",
|
||||
"CpuArchEnum",
|
||||
"_init_trace",
|
||||
"_is_amd_zen_cpu",
|
||||
]
|
||||
|
||||
@@ -167,6 +167,9 @@ class Platform:
|
||||
def is_cpu(self) -> bool:
|
||||
return self._enum == PlatformEnum.CPU
|
||||
|
||||
def is_zen_cpu(self) -> bool:
    """Whether this is the AMD Zen-optimized CPU platform.

    Base implementation returns False; ZenCpuPlatform overrides to True.
    """
    return False
|
||||
|
||||
def is_out_of_tree(self) -> bool:
|
||||
return self._enum == PlatformEnum.OOT
|
||||
|
||||
|
||||
67
vllm/platforms/zen_cpu.py
Normal file
67
vllm/platforms/zen_cpu.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
class ZenCpuPlatform(CpuPlatform):
    """CPU platform with AMD Zen (ZenDNN/zentorch) optimizations.

    Model-load time (dispatch_cpu_unquantized_gemm in layers/utils.py):
    - Routes linear ops to zentorch_linear_unary.
    - When VLLM_ZENTORCH_WEIGHT_PREPACK=1 (default), eagerly prepacks
      weights via zentorch_weight_prepack_for_linear.
    """

    # Reported as a plain CPU device; the Zen specialization is purely
    # behavioral (op routing + backports), not a distinct device type.
    device_name: str = "cpu"
    device_type: str = "cpu"

    def is_zen_cpu(self) -> bool:
        # is_cpu() also returns True for this platform (inherited from CpuPlatform).
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """Apply the standard CPU config updates, then the Zen-specific
        PyTorch backports."""
        super().check_and_update_config(vllm_config)
        cls._apply_pytorch_backports()

    @classmethod
    def _apply_pytorch_backports(cls):
        """Backport PyTorch mainline fixes missing in 2.10.

        PyTorch 2.10 has a bug in FxGraphCachePickler.dumps that doesn't
        catch ValueError, causing torch.compile cache misses. Remove this
        once we drop PyTorch 2.10 support. PT mainline already has this fix.
        """
        # Gate: only 2.10.x needs the patch — older versions are out of
        # scope and >=2.11 already carries the fix upstream.
        if not is_torch_equal_or_newer("2.10.0") or is_torch_equal_or_newer("2.11.0"):
            return

        cls._patch_fxgraphcache_pickle()

    @classmethod
    def _patch_fxgraphcache_pickle(cls):
        """Backport mainline ValueError fix to FxGraphCachePickler.dumps()."""
        # Imported lazily: torch._inductor is heavyweight and private API.
        from torch._inductor.codecache import BypassFxGraphCache, FxGraphCachePickler

        original_dumps = FxGraphCachePickler.dumps
        # Idempotence guard: calling this twice must not double-wrap.
        if hasattr(original_dumps, "_zen_patched"):
            return

        # NOTE(review): this assumes `dumps` is a plain function on the
        # class (so `original_dumps(self, obj)` forwards correctly). If it
        # is a classmethod descriptor in the targeted 2.10 build, the
        # forwarding would differ — confirm against the PyTorch source.
        def patched_dumps(self, obj):
            try:
                return original_dumps(self, obj)
            except ValueError as e:
                # Mainline behavior: an unpicklable cache key bypasses the
                # FX graph cache instead of propagating and failing compile.
                raise BypassFxGraphCache("Failed to pickle cache key") from e

        patched_dumps._zen_patched = True  # type: ignore[attr-defined]
        FxGraphCachePickler.dumps = patched_dumps
        logger.info("[zen_cpu] Patched FxGraphCachePickler.dumps (ValueError fix)")
|
||||
Reference in New Issue
Block a user