[ROCm] gemm_a16w16 upstreaming (#26969)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
This commit is contained in:
committed by
GitHub
parent
1fb4217a05
commit
2d977a7a9e
@@ -103,12 +103,41 @@ def default_unquantized_gemm(
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
def use_aiter_triton_gemm(n, m, k, dtype):
|
||||
if (
|
||||
envs.VLLM_ROCM_USE_AITER == 0
|
||||
# MI300's - fp8nuz=True
|
||||
or current_platform.is_fp8_fnuz()
|
||||
or dtype not in [torch.float16, torch.bfloat16]
|
||||
):
|
||||
return False
|
||||
|
||||
# use hipblaslt for the larger GEMMs
|
||||
if n > 2048 and m > 512:
|
||||
return False
|
||||
return (
|
||||
(m == 5120 and k == 2880)
|
||||
or (m == 2880 and k == 4096)
|
||||
or (m == 128 and k == 2880)
|
||||
or (m == 640 and k == 2880)
|
||||
or (m == 2880 and k == 512)
|
||||
)
|
||||
|
||||
|
||||
def rocm_unquantized_gemm_impl(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_gfx9
|
||||
|
||||
n = x.numel() / x.size(-1)
|
||||
m = weight.shape[0]
|
||||
k = weight.shape[1]
|
||||
|
||||
if use_aiter_triton_gemm(n, m, k, x.dtype):
|
||||
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
||||
|
||||
return gemm_a16w16(x, weight, bias)
|
||||
|
||||
use_skinny = (
|
||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||
and on_gfx9()
|
||||
@@ -120,11 +149,8 @@ def rocm_unquantized_gemm_impl(
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
x_view = x.reshape(-1, x.size(-1))
|
||||
n = x_view.shape[0]
|
||||
m = weight.shape[0]
|
||||
cu_count = current_platform.get_cu_count()
|
||||
|
||||
if m > 8 and 0 < n <= 4:
|
||||
cu_count = current_platform.get_cu_count()
|
||||
out = ops.wvSplitK(weight, x_view, cu_count, bias)
|
||||
return out.reshape(*x.shape[:-1], weight.shape[0])
|
||||
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
|
||||
@@ -133,7 +159,7 @@ def rocm_unquantized_gemm_impl(
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
def rocm_unquantized_gemm_impl_fake(
|
||||
def rocm_unquantized_gemm_fake(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return x.new_empty((*x.shape[:-1], weight.shape[0]))
|
||||
@@ -145,13 +171,13 @@ def rocm_unquantized_gemm(
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
|
||||
return torch.ops.vllm.rocm_unquantized_gemm(x, weight, bias)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_unquantized_gemm_impl",
|
||||
op_name="rocm_unquantized_gemm",
|
||||
op_func=rocm_unquantized_gemm_impl,
|
||||
fake_impl=rocm_unquantized_gemm_impl_fake,
|
||||
fake_impl=rocm_unquantized_gemm_fake,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user