Atomics Reduce Counting Optimization for SplitK Skinny GEMMs (#29843)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
@@ -129,12 +129,32 @@ def use_aiter_triton_gemm(n, m, k, dtype):
|
||||
def rocm_unquantized_gemm_impl(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_gfx9
|
||||
from vllm.platforms.rocm import on_gfx9, on_gfx950
|
||||
|
||||
n = x.numel() / x.size(-1)
|
||||
m = weight.shape[0]
|
||||
k = weight.shape[1]
|
||||
|
||||
import math
|
||||
|
||||
use_skinny_reduce_counting = (
|
||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||
and on_gfx950()
|
||||
and x.dtype in [torch.float16, torch.bfloat16]
|
||||
and (
|
||||
n >= 16
|
||||
and n <= 128
|
||||
and k > 512
|
||||
and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count()
|
||||
)
|
||||
# k == 2880 and (m == 640 or m == 128))
|
||||
)
|
||||
if use_skinny_reduce_counting:
|
||||
cu_count = get_cu_count()
|
||||
x_view = x.reshape(-1, x.size(-1))
|
||||
out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
|
||||
return out.reshape(*x.shape[:-1], weight.shape[0])
|
||||
|
||||
if use_aiter_triton_gemm(n, m, k, x.dtype):
|
||||
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
||||
|
||||
|
||||
Reference in New Issue
Block a user