[ROCm][Bugfix] Fix the case where there's bias (#24895)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
Gregory Shtrasberg
2025-09-15 22:05:12 -04:00
committed by GitHub
parent de2cc3d867
commit 2891603efd
2 changed files with 32 additions and 1 deletions

View File

@@ -179,7 +179,7 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
bias: torch.Tensor) -> torch.Tensor:
from vllm.platforms.rocm import on_mi3xx
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0 and bias is None:
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count())
else: