[ROCm] Enable wvSplitK skinny GEMM kernel for RDNA4/gfx1x decode (#34709)

Signed-off-by: L.B.R. <lbr@mmonad.com>
Co-authored-by: L.B.R. <lbr@mmonad.com>
This commit is contained in:
L.B.R.
2026-03-20 15:11:23 +00:00
committed by GitHub
parent 44eea10f68
commit 1779c09898
4 changed files with 365 additions and 99 deletions

View File

@@ -160,6 +160,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, padded_a, bias_mode
BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
elif bias_mode == 2:
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
elif bias_mode == 3:
BIAS = torch.rand(1, m, dtype=dtype, device="cuda") * 2 - 1
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitKrc(A, B, cu_count, BIAS)
@@ -224,10 +226,9 @@ def test_rocm_wvsplitk_kernel(
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
if xnorm:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
else:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2)
# Accumulation error in fp16 GEMM scales with sqrt(K)
atol = torch.finfo(dtype).eps * math.sqrt(k)
torch.testing.assert_close(out, ref_out, atol=atol, rtol=1e-2)
@pytest.mark.parametrize("xnorm", [False, True])