[ROCm][Kernel] Add gfx950 support for skinny gemms (#18010)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu
2025-05-31 09:40:05 -05:00
committed by GitHub
parent f2c3f66d59
commit 306d60401d
5 changed files with 91 additions and 54 deletions

View File

@@ -155,8 +155,8 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list) -> torch.Tensor:
from vllm.platforms.rocm import on_mi250_mi300
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
from vllm.platforms.rocm import on_mi3xx
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count())