[ROCm] Enable wvSplitK skinny GEMM kernel for RDNA4/gfx1x decode (#34709)

Signed-off-by: L.B.R. <lbr@mmonad.com>
Co-authored-by: L.B.R. <lbr@mmonad.com>
This commit is contained in:
L.B.R.
2026-03-20 15:11:23 +00:00
committed by GitHub
parent 44eea10f68
commit 1779c09898
4 changed files with 365 additions and 99 deletions

View File

@@ -122,7 +122,7 @@ def use_aiter_triton_gemm(n, m, k, dtype):
def rocm_unquantized_gemm_impl(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
from vllm.platforms.rocm import on_gfx9, on_gfx950
from vllm.platforms.rocm import on_gfx1x, on_gfx9, on_gfx950
n = x.numel() // x.size(-1)
m = weight.shape[0]
@@ -169,12 +169,12 @@ def rocm_unquantized_gemm_impl(
use_skinny = (
envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx9()
and (on_gfx9() or on_gfx1x())
and x.dtype in [torch.float16, torch.bfloat16]
and k % 8 == 0
)
if use_skinny is not True:
if not use_skinny:
return torch.nn.functional.linear(x, weight, bias)
x_view = x.reshape(-1, x.size(-1))