Perf tuning and expansion of cases covered for wvSplitKrc (#33493)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
@@ -145,32 +145,43 @@ def rocm_unquantized_gemm_impl(
|
||||
) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_gfx9, on_gfx950
|
||||
|
||||
n = x.numel() / x.size(-1)
|
||||
n = x.numel() // x.size(-1)
|
||||
m = weight.shape[0]
|
||||
k = weight.shape[1]
|
||||
|
||||
import math
|
||||
|
||||
cu_count = get_cu_count()
|
||||
if use_aiter_triton_gemm(n, m, k, x.dtype):
|
||||
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
||||
|
||||
return gemm_a16w16(x, weight, bias)
|
||||
|
||||
# Next ^2 of n
|
||||
N_p2 = 1 << (n - 1).bit_length()
|
||||
# With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
|
||||
# and each working on a 512-shard of K, how many CUs would we need?
|
||||
rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
|
||||
# How many of 4 waves in a group can work on same 16 Ms at same time?
|
||||
# This reduces the Ms each group works on, i.e. increasing the number of CUs needed.
|
||||
GrpsShrB = min(N_p2 // 16, 4)
|
||||
# Given the above, how many CUs would we need?
|
||||
CuNeeded = rndup_cus * GrpsShrB
|
||||
# candidate for atomic reduce count splitk?
|
||||
fits_wvsplitkrc = CuNeeded <= cu_count
|
||||
|
||||
use_skinny_reduce_counting = (
|
||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||
and on_gfx950()
|
||||
and x.dtype in [torch.float16, torch.bfloat16]
|
||||
and (
|
||||
n >= 16
|
||||
and n <= 128
|
||||
10 <= n <= 128
|
||||
and k % 8 == 0
|
||||
and k > 512
|
||||
and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count()
|
||||
and m % 16 == 0
|
||||
and fits_wvsplitkrc
|
||||
and x.is_contiguous()
|
||||
)
|
||||
# k == 2880 and (m == 640 or m == 128))
|
||||
)
|
||||
if use_skinny_reduce_counting:
|
||||
cu_count = get_cu_count()
|
||||
x_view = x.reshape(-1, x.size(-1))
|
||||
out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
|
||||
return out.reshape(*x.shape[:-1], weight.shape[0])
|
||||
|
||||
Reference in New Issue
Block a user