Files
vllm/tests/model_executor/test_cpu_unquantized_gemm_dispatch.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

69 lines
2.1 KiB
Python
Raw Normal View History

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for CPU unquantized GEMM dispatch behavior."""
import pytest
import torch
from vllm.model_executor.layers import utils
from vllm.platforms import current_platform
@pytest.fixture(scope="module")
def _mock_zentorch_linear_unary():
    """Provide a ``zentorch_linear_unary`` op for the dispatch tests.

    When ``torch.ops.zentorch`` already exposes ``zentorch_linear_unary`` the
    fixture registers nothing and simply yields. Otherwise it defines the op
    schema in the ``zentorch`` namespace and registers a CPU implementation
    that forwards to ``torch.nn.functional.linear``, so the tests can run
    without a real zentorch build.

    The registrations are torn down with nested ``try``/``finally`` blocks so
    that each library object is destroyed once it has been created, even if a
    later setup step (or the other library's teardown) raises.

    Yields:
        None. The fixture is used only for its registration side effect.
    """
    if hasattr(torch.ops.zentorch, "zentorch_linear_unary"):
        yield
        return

    def _linear_stub(input, weight, bias, is_weight_prepacked=False):
        # Parameter names mirror the op schema declared below; the
        # prepacking flag is accepted but ignored by this stand-in.
        return torch.nn.functional.linear(input, weight, bias)

    lib_def = torch.library.Library("zentorch", "DEF")
    try:
        lib_def.define(
            "zentorch_linear_unary("
            "Tensor input, "
            "Tensor weight, "
            "Tensor? bias, "
            "bool is_weight_prepacked=False"
            ") -> Tensor"
        )
        lib_impl = torch.library.Library("zentorch", "IMPL", "CPU")
        try:
            lib_impl.impl("zentorch_linear_unary", _linear_stub)
            yield
        finally:
            # Drop the implementation before the definition, matching the
            # original teardown order.
            lib_impl._destroy()
    finally:
        lib_def._destroy()
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
def test_dispatch_cpu_unquantized_gemm_uses_zentorch_on_zen(monkeypatch):
    """Check the dispatched ``cpu_linear`` matches a plain linear on Zen.

    The platform is patched to report a Zen CPU, the dispatcher is run with
    ``remove_weight=False``, and the callable it installs on the layer is
    compared against ``torch.nn.functional.linear`` on the same inputs.
    """
    # Make the platform report a Zen CPU for the duration of this test.
    monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)

    linear = torch.nn.Linear(in_features=16, out_features=8, bias=True)
    inputs = torch.randn(4, 16)

    # Reference result is computed before dispatch touches the layer.
    reference = torch.nn.functional.linear(inputs, linear.weight, linear.bias)

    utils.dispatch_cpu_unquantized_gemm(linear, remove_weight=False)
    actual = linear.cpu_linear(inputs, linear.weight, linear.bias)

    torch.testing.assert_close(actual, reference)
@pytest.mark.usefixtures("_mock_zentorch_linear_unary")
def test_dispatch_cpu_unquantized_gemm_zen_remove_weight(monkeypatch):
    """Check that ``remove_weight=True`` leaves the layer weight empty on Zen.

    The platform is patched to report a Zen CPU before dispatching; afterwards
    the layer's ``weight`` is expected to hold zero elements.
    """
    # Make the platform report a Zen CPU for the duration of this test.
    monkeypatch.setattr(current_platform, "is_zen_cpu", lambda: True)

    linear = torch.nn.Linear(in_features=16, out_features=8, bias=True)
    utils.dispatch_cpu_unquantized_gemm(linear, remove_weight=True)

    remaining_elements = linear.weight.numel()
    assert remaining_elements == 0