Signed-off-by: Lalithnarayan C <Lalithnarayan.C@amd.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Chinmay-Kulkarni-AMD <Chinmay.Kulkarni@amd.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
69 lines
2.1 KiB
Python
69 lines
2.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Tests for CPU unquantized GEMM dispatch behavior."""
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from vllm.model_executor.layers import utils
|
|
from vllm.platforms import current_platform
|
|
|
|
|
|
@pytest.fixture(scope="module")
|
|
def _mock_zentorch_linear_unary():
|
|
"""Register a mock zentorch_linear_unary op when zentorch is not installed.
|
|
|
|
Allows the dispatch tests to run in CI without a real zentorch build.
|
|
Skips registration when zentorch is already available.
|
|
"""
|
|
if hasattr(torch.ops.zentorch, "zentorch_linear_unary"):
|
|
yield
|
|
return
|
|
|
|
lib_def = torch.library.Library("zentorch", "DEF")
|
|
lib_def.define(
|
|
"zentorch_linear_unary("
|
|
"Tensor input, "
|
|
"Tensor weight, "
|
|
"Tensor? bias, "
|
|
"bool is_weight_prepacked=False"
|
|
") -> Tensor"
|
|
)
|
|
|
|
lib_impl = torch.library.Library("zentorch", "IMPL", "CPU")
|
|
lib_impl.impl(
|
|
"zentorch_linear_unary",
|
|
lambda input, weight, bias, is_weight_prepacked=False: (
|
|
torch.nn.functional.linear(input, weight, bias)
|
|
),
|
|
)
|
|
|
|
yield
|
|
|
|
lib_impl._destroy()
|
|
lib_def._destroy()
|
|
|
|
|
|
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
|
|
def test_dispatch_cpu_unquantized_gemm_uses_zentorch_on_zen(monkeypatch):
|
|
monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)
|
|
|
|
layer = torch.nn.Linear(16, 8, bias=True)
|
|
x = torch.randn(4, 16)
|
|
expected = torch.nn.functional.linear(x, layer.weight, layer.bias)
|
|
|
|
utils.dispatch_cpu_unquantized_gemm(layer, remove_weight=False)
|
|
output = layer.cpu_linear(x, layer.weight, layer.bias)
|
|
|
|
torch.testing.assert_close(output, expected)
|
|
|
|
|
|
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
|
|
def test_dispatch_cpu_unquantized_gemm_zen_remove_weight(monkeypatch):
|
|
monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)
|
|
|
|
layer = torch.nn.Linear(16, 8, bias=True)
|
|
utils.dispatch_cpu_unquantized_gemm(layer, remove_weight=True)
|
|
|
|
assert layer.weight.numel() == 0
|