In-Tree AMD Zen CPU Backend via zentorch [1/N] (#35970)
Signed-off-by: Lalithnarayan C <Lalithnarayan.C@amd.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Chinmay-Kulkarni-AMD <Chinmay.Kulkarni@amd.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
68
tests/model_executor/test_cpu_unquantized_gemm_dispatch.py
Normal file
68
tests/model_executor/test_cpu_unquantized_gemm_dispatch.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for CPU unquantized GEMM dispatch behavior."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers import utils
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def _mock_zentorch_linear_unary():
|
||||
"""Register a mock zentorch_linear_unary op when zentorch is not installed.
|
||||
|
||||
Allows the dispatch tests to run in CI without a real zentorch build.
|
||||
Skips registration when zentorch is already available.
|
||||
"""
|
||||
if hasattr(torch.ops.zentorch, "zentorch_linear_unary"):
|
||||
yield
|
||||
return
|
||||
|
||||
lib_def = torch.library.Library("zentorch", "DEF")
|
||||
lib_def.define(
|
||||
"zentorch_linear_unary("
|
||||
"Tensor input, "
|
||||
"Tensor weight, "
|
||||
"Tensor? bias, "
|
||||
"bool is_weight_prepacked=False"
|
||||
") -> Tensor"
|
||||
)
|
||||
|
||||
lib_impl = torch.library.Library("zentorch", "IMPL", "CPU")
|
||||
lib_impl.impl(
|
||||
"zentorch_linear_unary",
|
||||
lambda input, weight, bias, is_weight_prepacked=False: (
|
||||
torch.nn.functional.linear(input, weight, bias)
|
||||
),
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
lib_impl._destroy()
|
||||
lib_def._destroy()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
|
||||
def test_dispatch_cpu_unquantized_gemm_uses_zentorch_on_zen(monkeypatch):
|
||||
monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)
|
||||
|
||||
layer = torch.nn.Linear(16, 8, bias=True)
|
||||
x = torch.randn(4, 16)
|
||||
expected = torch.nn.functional.linear(x, layer.weight, layer.bias)
|
||||
|
||||
utils.dispatch_cpu_unquantized_gemm(layer, remove_weight=False)
|
||||
output = layer.cpu_linear(x, layer.weight, layer.bias)
|
||||
|
||||
torch.testing.assert_close(output, expected)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
|
||||
def test_dispatch_cpu_unquantized_gemm_zen_remove_weight(monkeypatch):
|
||||
monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)
|
||||
|
||||
layer = torch.nn.Linear(16, 8, bias=True)
|
||||
utils.dispatch_cpu_unquantized_gemm(layer, remove_weight=True)
|
||||
|
||||
assert layer.weight.numel() == 0
|
||||
Reference in New Issue
Block a user