Atomics Reduce Counting Optimization for SplitK Skinny GEMMs. (#29843)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
Hashem Hashemi
2026-01-16 09:45:04 -08:00
committed by GitHub
parent 9fd918e510
commit 7a1030431a
6 changed files with 635 additions and 10 deletions

View File

@@ -8,9 +8,11 @@ import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx950
from vllm.utils.platform_utils import get_cu_count
DTYPES = [torch.bfloat16, torch.float16]
BIAS_MODES = [0, 1, 2]
# Specific (N, K, M) combinations for targeted testing
NKM_FACTORS_LLMM1 = [
# Small, medium, large cases
@@ -43,6 +45,31 @@ NKM_FACTORS_WVSPLITK = [
(4, 256, 8),
]
NKM_FACTORS_WVSPLITKRC = [
(16, 2880, 128),
(16, 2880, 640),
(17, 2880, 128),
(17, 2880, 640),
(25, 2880, 128),
(25, 2880, 640),
(31, 2880, 128),
(31, 2880, 640),
(32, 2880, 128),
(32, 2880, 640),
(40, 2880, 128),
(40, 2880, 640),
(60, 2880, 128),
(60, 2880, 640),
(64, 2880, 128),
(64, 2880, 640),
(81, 2880, 128),
(81, 2880, 640),
(98, 2880, 128),
(98, 2880, 640),
(128, 2880, 128),
(128, 2880, 640),
]
NKM_FACTORS_WVSPLITK_FP8 = [
# FP8-specific cases with K % 16 == 0
(1, 16, 16),
@@ -60,6 +87,32 @@ NKM_FACTORS_WVSPLITK_FP8 = [
SEEDS = [0]
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(n, k, m, dtype, seed, bias_mode):
torch.manual_seed(seed)
cu_count = get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
BIAS = None
if bias_mode == 1:
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
elif bias_mode == 2:
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS)
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])