[ROCm] Enable wvSplitK skinny GEMM kernel for RDNA4/gfx1x decode (#34709)
Signed-off-by: L.B.R. <lbr@mmonad.com> Co-authored-by: L.B.R. <lbr@mmonad.com>
This commit is contained in:
@@ -122,7 +122,7 @@ def use_aiter_triton_gemm(n, m, k, dtype):
|
||||
def rocm_unquantized_gemm_impl(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_gfx9, on_gfx950
|
||||
from vllm.platforms.rocm import on_gfx1x, on_gfx9, on_gfx950
|
||||
|
||||
n = x.numel() // x.size(-1)
|
||||
m = weight.shape[0]
|
||||
@@ -169,12 +169,12 @@ def rocm_unquantized_gemm_impl(
|
||||
|
||||
use_skinny = (
|
||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||
and on_gfx9()
|
||||
and (on_gfx9() or on_gfx1x())
|
||||
and x.dtype in [torch.float16, torch.bfloat16]
|
||||
and k % 8 == 0
|
||||
)
|
||||
|
||||
if use_skinny is not True:
|
||||
if not use_skinny:
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
x_view = x.reshape(-1, x.size(-1))
|
||||
|
||||
Reference in New Issue
Block a user