Perf tuning and expansion of cases covered for wvSplitKrc (#33493)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
Hashem Hashemi
2026-02-07 05:33:11 -08:00
committed by GitHub
parent 860981d8d8
commit ed17f54c8b
3 changed files with 214 additions and 223 deletions

View File

@@ -145,32 +145,43 @@ def rocm_unquantized_gemm_impl(
) -> torch.Tensor:
from vllm.platforms.rocm import on_gfx9, on_gfx950
n = x.numel() / x.size(-1)
n = x.numel() // x.size(-1)
m = weight.shape[0]
k = weight.shape[1]
import math
cu_count = get_cu_count()
if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
return gemm_a16w16(x, weight, bias)
# Next ^2 of n
N_p2 = 1 << (n - 1).bit_length()
# With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
# and each working on a 512-shard of K, how many CUs would we need?
rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
# How many of 4 waves in a group can work on same 16 Ms at same time?
# This reduces the Ms each group works on, i.e. increasing the number of CUs needed.
GrpsShrB = min(N_p2 // 16, 4)
# Given the above, how many CUs would we need?
CuNeeded = rndup_cus * GrpsShrB
# candidate for atomic reduce count splitk?
fits_wvsplitkrc = CuNeeded <= cu_count
use_skinny_reduce_counting = (
envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx950()
and x.dtype in [torch.float16, torch.bfloat16]
and (
n >= 16
and n <= 128
10 <= n <= 128
and k % 8 == 0
and k > 512
and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count()
and m % 16 == 0
and fits_wvsplitkrc
and x.is_contiguous()
)
# k == 2880 and (m == 640 or m == 128))
)
if use_skinny_reduce_counting:
cu_count = get_cu_count()
x_view = x.reshape(-1, x.size(-1))
out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
return out.reshape(*x.shape[:-1], weight.shape[0])