[ROCm] Enable wvSplitK skinny GEMM kernel for RDNA4/gfx1x decode (#34709)
Signed-off-by: L.B.R. <lbr@mmonad.com> Co-authored-by: L.B.R. <lbr@mmonad.com>
This commit is contained in:
@@ -160,6 +160,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, padded_a, bias_mode
|
||||
BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
|
||||
elif bias_mode == 2:
|
||||
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
|
||||
elif bias_mode == 3:
|
||||
BIAS = torch.rand(1, m, dtype=dtype, device="cuda") * 2 - 1
|
||||
|
||||
ref_out = torch.nn.functional.linear(A, B, BIAS)
|
||||
out = ops.wvSplitKrc(A, B, cu_count, BIAS)
|
||||
@@ -224,10 +226,9 @@ def test_rocm_wvsplitk_kernel(
|
||||
ref_out = torch.nn.functional.linear(A, B, BIAS)
|
||||
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
|
||||
|
||||
if xnorm:
|
||||
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
|
||||
else:
|
||||
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2)
|
||||
# Accumulation error in fp16 GEMM scales with sqrt(K)
|
||||
atol = torch.finfo(dtype).eps * math.sqrt(k)
|
||||
torch.testing.assert_close(out, ref_out, atol=atol, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("xnorm", [False, True])
|
||||
|
||||
Reference in New Issue
Block a user