Perf tuning and expansion of cases covered for wvSplitKrc (#33493)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
Hashem Hashemi
2026-02-07 05:33:11 -08:00
committed by GitHub
parent 860981d8d8
commit ed17f54c8b
3 changed files with 214 additions and 223 deletions

View File

@@ -45,31 +45,28 @@ NKM_FACTORS_WVSPLITK = [
(4, 256, 8),
]
NKM_FACTORS_WVSPLITKRC = [
(16, 2880, 128),
(16, 2880, 640),
(17, 2880, 128),
(17, 2880, 640),
(25, 2880, 128),
(25, 2880, 640),
(31, 2880, 128),
(31, 2880, 640),
(32, 2880, 128),
(32, 2880, 640),
(40, 2880, 128),
(40, 2880, 640),
(60, 2880, 128),
(60, 2880, 640),
(64, 2880, 128),
(64, 2880, 640),
(81, 2880, 128),
(81, 2880, 640),
(98, 2880, 128),
(98, 2880, 640),
(128, 2880, 128),
(128, 2880, 640),
N_FACTORS_WVSPLITKRC = [
13,
16,
17,
25,
29,
31,
32,
41,
51,
64,
71,
81,
91,
103,
117,
128,
]
K_FACTORS_WVSPLITKRC = [2880, 2880 + 8, 3072, 3072 + 8]
M_FACTORS_WVSPLITKRC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16]
NKM_FACTORS_WVSPLITK_FP8 = [
# FP8-specific cases with K % 16 == 0
(1, 16, 16),
@@ -113,30 +110,54 @@ def pad_fp8(weight):
return F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n", N_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("k", K_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("m", M_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
def test_rocm_wvsplitkrc_kernel(n, k, m, dtype, seed, bias_mode):
def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
torch.manual_seed(seed)
cu_count = get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
# Next ^2 of n
N_p2 = 1 << (n - 1).bit_length()
# With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
# and each working on a 512-shard of K, how many CUs would we need?
rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
# How many of 4 waves in a group can work on same 16 Ms at same time?
# This reduces the Ms each group works on, i.e. increasing the number of CUs needed.
GrpsShrB = min(N_p2 // 16, 4)
# Given the above, how many CUs would we need?
CuNeeded = rndup_cus * GrpsShrB
# candidate for atomic reduce count splitk?
fits_wvsplitkrc = CuNeeded <= cu_count
if not fits_wvsplitkrc:
pytest.skip("Too large for wvSplitKrc")
xavier = (
math.sqrt(2 / k) if xnorm else 1
) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
BIAS = None
if bias_mode == 1:
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
elif bias_mode == 2:
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS)
assert torch.allclose(out, ref_out, rtol=0.01)
if xnorm:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
else:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)