[ROCm] gemm_a16w16 upstreaming (#26969)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Author: Aleksandr Malyshev
Date: 2025-11-04 13:01:00 -08:00 (committed by GitHub)
Parent: 1fb4217a05
Commit: 2d977a7a9e
2 changed files with 43 additions and 9 deletions


@@ -103,12 +103,41 @@ def default_unquantized_gemm(
     return torch.nn.functional.linear(x, weight, bias)
 
 
+def use_aiter_triton_gemm(n, m, k, dtype):
+    if (
+        envs.VLLM_ROCM_USE_AITER == 0
+        # MI300's - fp8nuz=True
+        or current_platform.is_fp8_fnuz()
+        or dtype not in [torch.float16, torch.bfloat16]
+    ):
+        return False
+
+    # use hipblaslt for the larger GEMMs
+    if n > 2048 and m > 512:
+        return False
+    return (
+        (m == 5120 and k == 2880)
+        or (m == 2880 and k == 4096)
+        or (m == 128 and k == 2880)
+        or (m == 640 and k == 2880)
+        or (m == 2880 and k == 512)
+    )
+
+
 def rocm_unquantized_gemm_impl(
     x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
 ) -> torch.Tensor:
     from vllm.platforms.rocm import on_gfx9
 
+    n = x.numel() // x.size(-1)
+    m = weight.shape[0]
+    k = weight.shape[1]
+
+    if use_aiter_triton_gemm(n, m, k, x.dtype):
+        from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+
+        return gemm_a16w16(x, weight, bias)
+
     use_skinny = (
         envs.VLLM_ROCM_USE_SKINNY_GEMM
         and on_gfx9()
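
The new use_aiter_triton_gemm gate routes only small-batch GEMMs with whitelisted (m, k) weight shapes to AITER's Triton kernel and leaves everything else on the default hipBLASLt path. Below is a minimal standalone sketch of that shape logic, assuming the convention from rocm_unquantized_gemm_impl (x is (..., k), weight is (m, k), n is the flattened row count); it omits the VLLM_ROCM_USE_AITER and fp8-fnuz gates, and would_use_aiter_triton_gemm / AITER_MK_WHITELIST are hypothetical names, not part of the patch.

import torch

# (m, k) pairs whitelisted in the hunk above (hypothetical constant name).
AITER_MK_WHITELIST = {
    (5120, 2880), (2880, 4096), (128, 2880), (640, 2880), (2880, 512),
}

def would_use_aiter_triton_gemm(x: torch.Tensor, weight: torch.Tensor) -> bool:
    n = x.numel() // x.size(-1)  # rows after flattening batch dims
    m, k = weight.shape          # output features, reduction dim
    if x.dtype not in (torch.float16, torch.bfloat16):
        return False
    if n > 2048 and m > 512:     # larger GEMMs stay on hipBLASLt
        return False
    return (m, k) in AITER_MK_WHITELIST

# 7 tokens of hidden size 2880 against a 5120x2880 weight hit the AITER path:
x = torch.empty(7, 2880, dtype=torch.bfloat16)
w = torch.empty(5120, 2880, dtype=torch.bfloat16)
assert would_use_aiter_triton_gemm(x, w)
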
@@ -120,11 +149,8 @@ def rocm_unquantized_gemm_impl(
         return torch.nn.functional.linear(x, weight, bias)
 
     x_view = x.reshape(-1, x.size(-1))
-    n = x_view.shape[0]
-    m = weight.shape[0]
-    cu_count = current_platform.get_cu_count()
-
     if m > 8 and 0 < n <= 4:
+        cu_count = current_platform.get_cu_count()
         out = ops.wvSplitK(weight, x_view, cu_count, bias)
         return out.reshape(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
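
With n, m, and k now computed once at the top of rocm_unquantized_gemm_impl, this hunk drops the duplicate assignments and fetches the CU count only inside the branch that actually needs it, so the common non-skinny path skips the query. The reshape bookkeeping around the skinny kernel is easy to check in isolation; in this sketch torch.nn.functional.linear stands in for the ops.wvSplitK custom kernel (a stand-in, not the real call):

import torch

def skinny_gemm_shape_flow(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Flatten all batch dims into n rows, run the (n, k) x (m, k)^T GEMM,
    # then restore the original batch dims with m output features.
    x_view = x.reshape(-1, x.size(-1))
    out = torch.nn.functional.linear(x_view, weight)  # stand-in for wvSplitK
    return out.reshape(*x.shape[:-1], weight.shape[0])

x = torch.randn(2, 3, 64)   # two batch dims, k = 64
w = torch.randn(16, 64)     # m = 16
assert skinny_gemm_shape_flow(x, w).shape == (2, 3, 16)
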
@@ -133,7 +159,7 @@ def rocm_unquantized_gemm_impl(
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def rocm_unquantized_gemm_impl_fake(
+def rocm_unquantized_gemm_fake(
     x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
 ) -> torch.Tensor:
     return x.new_empty((*x.shape[:-1], weight.shape[0]))
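
The renamed fake implementation is the meta function used when torch.compile or FakeTensor tracing runs the op without touching the device: it only has to produce an output with the correct shape, dtype, and device, never real data. A quick standalone check of that shape contract (sketch only, using a local function rather than the registered op):

import torch

def fake_gemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Same contract as rocm_unquantized_gemm_fake: metadata only, no compute.
    return x.new_empty((*x.shape[:-1], weight.shape[0]))

x = torch.empty(4, 7, 2880, dtype=torch.bfloat16)
w = torch.empty(5120, 2880, dtype=torch.bfloat16)
out = fake_gemm(x, w)
assert out.shape == (4, 7, 5120) and out.dtype == x.dtype
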
@@ -145,13 +171,13 @@ def rocm_unquantized_gemm(
     weight: torch.Tensor,
     bias: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
+    return torch.ops.vllm.rocm_unquantized_gemm(x, weight, bias)
 
 
 direct_register_custom_op(
-    op_name="rocm_unquantized_gemm_impl",
+    op_name="rocm_unquantized_gemm",
     op_func=rocm_unquantized_gemm_impl,
-    fake_impl=rocm_unquantized_gemm_impl_fake,
+    fake_impl=rocm_unquantized_gemm_fake,
 )
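
The registration rename drops the _impl suffix from the public op name, so callers now dispatch through torch.ops.vllm.rocm_unquantized_gemm while op_func still points at the Python implementation. For readers unfamiliar with direct_register_custom_op, the sketch below shows a rough equivalent using the public torch.library API (PyTorch 2.4+); the demo namespace and op name are made up for illustration, and vLLM's helper additionally handles dispatch keys and schema details:

import torch

@torch.library.custom_op("demo::unquantized_gemm", mutates_args=())
def unquantized_gemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Real implementation, run on concrete tensors.
    return torch.nn.functional.linear(x, weight)

@unquantized_gemm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Fake implementation, run under FakeTensor / torch.compile tracing.
    return x.new_empty((*x.shape[:-1], weight.shape[0]))

# Callers go through torch.ops, so the compiler sees one opaque custom op:
y = torch.ops.demo.unquantized_gemm(torch.randn(2, 8), torch.randn(4, 8))
assert y.shape == (2, 4)
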