[ROCm] Enable wvSplitK skinny GEMM kernel for RDNA4/gfx1x decode (#34709)
Signed-off-by: L.B.R. <lbr@mmonad.com> Co-authored-by: L.B.R. <lbr@mmonad.com>
This commit is contained in:
89
tests/model_executor/layers/test_rocm_unquantized_gemm.py
Normal file
89
tests/model_executor/layers/test_rocm_unquantized_gemm.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda():
|
||||
pytest.skip(
|
||||
"ROCm skinny GEMM tests are not supported on CUDA.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers import utils
|
||||
|
||||
|
||||
def test_rocm_unquantized_gemm_gfx1x_wvsplitk_path(monkeypatch):
|
||||
x = torch.randn(1, 64, dtype=torch.float16)
|
||||
weight = torch.randn(128, 64, dtype=torch.float16)
|
||||
|
||||
monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
|
||||
monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
|
||||
monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: True)
|
||||
monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
|
||||
monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: False)
|
||||
monkeypatch.setattr(utils, "get_cu_count", lambda: 120)
|
||||
|
||||
wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
|
||||
monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)
|
||||
llmm1_mock = MagicMock(side_effect=lambda w, x_view, _: x_view @ w.t())
|
||||
monkeypatch.setattr(utils.ops, "LLMM1", llmm1_mock)
|
||||
|
||||
out = utils.rocm_unquantized_gemm_impl(x, weight, None)
|
||||
ref = torch.nn.functional.linear(x, weight, None)
|
||||
|
||||
wvsplitk_mock.assert_called_once()
|
||||
llmm1_mock.assert_not_called()
|
||||
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def test_rocm_unquantized_gemm_gfx1x_n_gt_4_falls_back(monkeypatch):
|
||||
x = torch.randn(5, 64, dtype=torch.float16)
|
||||
weight = torch.randn(128, 64, dtype=torch.float16)
|
||||
|
||||
monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
|
||||
monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
|
||||
monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: True)
|
||||
monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
|
||||
monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: False)
|
||||
monkeypatch.setattr(utils, "get_cu_count", lambda: 120)
|
||||
|
||||
wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
|
||||
monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)
|
||||
llmm1_mock = MagicMock(side_effect=lambda w, x_view, _: x_view @ w.t())
|
||||
monkeypatch.setattr(utils.ops, "LLMM1", llmm1_mock)
|
||||
|
||||
out = utils.rocm_unquantized_gemm_impl(x, weight, None)
|
||||
ref = torch.nn.functional.linear(x, weight, None)
|
||||
|
||||
wvsplitk_mock.assert_not_called()
|
||||
llmm1_mock.assert_not_called()
|
||||
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def test_rocm_unquantized_gemm_gfx950_wvsplitkrc_path(monkeypatch):
|
||||
x = torch.randn(16, 1024, dtype=torch.float16)
|
||||
weight = torch.randn(256, 1024, dtype=torch.float16)
|
||||
|
||||
monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
|
||||
monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
|
||||
monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: False)
|
||||
monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
|
||||
monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: True)
|
||||
monkeypatch.setattr(utils, "get_cu_count", lambda: 120)
|
||||
|
||||
wvsplitkrc_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
|
||||
monkeypatch.setattr(utils.ops, "wvSplitKrc", wvsplitkrc_mock)
|
||||
wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
|
||||
monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)
|
||||
|
||||
out = utils.rocm_unquantized_gemm_impl(x, weight, None)
|
||||
ref = torch.nn.functional.linear(x, weight, None)
|
||||
|
||||
wvsplitkrc_mock.assert_called_once()
|
||||
wvsplitk_mock.assert_not_called()
|
||||
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
|
||||
Reference in New Issue
Block a user