# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for the ROCm skinny-GEMM dispatch in
``vllm.model_executor.layers.utils.rocm_unquantized_gemm_impl``.

Each test monkeypatches the gfx-architecture probes and the custom ops so the
dispatch decision can be exercised on any hardware; the mocked kernels compute
the same result as ``torch.nn.functional.linear`` so the numerical comparison
also holds.
"""

from unittest.mock import MagicMock

import pytest
import torch

from vllm.platforms import current_platform

if current_platform.is_cuda():
    pytest.skip(
        "ROCm skinny GEMM tests are not supported on CUDA.",
        allow_module_level=True,
    )

from vllm.model_executor.layers import utils


def _patch_skinny_gemm_platform(monkeypatch, *, gfx1x: bool, gfx950: bool):
    """Pin the skinny-GEMM environment/platform probes to fixed values.

    Disables the aiter-triton path, enables ``VLLM_ROCM_USE_SKINNY_GEMM``, and
    forces the gfx architecture checks so the dispatch is deterministic
    regardless of the machine running the tests.
    """
    monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
    monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: gfx1x)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: gfx950)
    monkeypatch.setattr(utils, "num_compute_units", lambda: 120)


def _mock_wvsplitk(monkeypatch) -> MagicMock:
    """Replace ``ops.wvSplitK`` with a mock numerically equal to F.linear."""
    mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
    monkeypatch.setattr(utils.ops, "wvSplitK", mock)
    return mock


def _mock_llmm1(monkeypatch) -> MagicMock:
    """Replace ``ops.LLMM1`` with a mock numerically equal to F.linear."""
    mock = MagicMock(side_effect=lambda w, x_view, _: x_view @ w.t())
    monkeypatch.setattr(utils.ops, "LLMM1", mock)
    return mock


def test_rocm_unquantized_gemm_gfx1x_wvsplitk_path(monkeypatch):
    # A single-row input (n == 1) on gfx1x must dispatch to wvSplitK.
    x = torch.randn(1, 64, dtype=torch.float16)
    weight = torch.randn(128, 64, dtype=torch.float16)

    _patch_skinny_gemm_platform(monkeypatch, gfx1x=True, gfx950=False)
    wvsplitk_mock = _mock_wvsplitk(monkeypatch)
    llmm1_mock = _mock_llmm1(monkeypatch)

    out = utils.rocm_unquantized_gemm_impl(x, weight, None)
    ref = torch.nn.functional.linear(x, weight, None)

    wvsplitk_mock.assert_called_once()
    llmm1_mock.assert_not_called()
    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)


def test_rocm_unquantized_gemm_gfx1x_n_gt_4_falls_back(monkeypatch):
    # n == 5 exceeds the gfx1x skinny-GEMM limit (n > 4): neither custom
    # kernel may fire, and the impl must fall back to the plain matmul.
    x = torch.randn(5, 64, dtype=torch.float16)
    weight = torch.randn(128, 64, dtype=torch.float16)

    _patch_skinny_gemm_platform(monkeypatch, gfx1x=True, gfx950=False)
    wvsplitk_mock = _mock_wvsplitk(monkeypatch)
    llmm1_mock = _mock_llmm1(monkeypatch)

    out = utils.rocm_unquantized_gemm_impl(x, weight, None)
    ref = torch.nn.functional.linear(x, weight, None)

    wvsplitk_mock.assert_not_called()
    llmm1_mock.assert_not_called()
    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)


def test_rocm_unquantized_gemm_gfx950_wvsplitkrc_path(monkeypatch):
    # On gfx950, n == 16 must dispatch to wvSplitKrc rather than wvSplitK.
    x = torch.randn(16, 1024, dtype=torch.float16)
    weight = torch.randn(256, 1024, dtype=torch.float16)

    _patch_skinny_gemm_platform(monkeypatch, gfx1x=False, gfx950=True)
    # NOTE(review): wvSplitKrc appears to take (x, weight, ...) — the mirror
    # of wvSplitK's (weight, x, ...) order; mirrored from the original mocks.
    wvsplitkrc_mock = MagicMock(
        side_effect=lambda x_view, w, _, __: x_view @ w.t()
    )
    monkeypatch.setattr(utils.ops, "wvSplitKrc", wvsplitkrc_mock)
    wvsplitk_mock = _mock_wvsplitk(monkeypatch)

    out = utils.rocm_unquantized_gemm_impl(x, weight, None)
    ref = torch.nn.functional.linear(x, weight, None)

    wvsplitkrc_mock.assert_called_once()
    wvsplitk_mock.assert_not_called()
    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)