[Bugfix][ROCm] Fixing the skinny gemm dispatch logic from #32831 (#33366)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
Gregory Shtrasberg
2026-01-30 19:05:23 -06:00
committed by GitHub
parent 67ebaff528
commit 31aedfe7d6
4 changed files with 18 additions and 18 deletions

View File

@@ -28,6 +28,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
A.shape[0] == 1
and B.shape[1] % 16 == 0
and ((bias is None) or (bias.dtype == out_dtype))
and A.is_contiguous()
):
output = ops.wvSplitKQ(
B.t(),

View File

@@ -165,6 +165,7 @@ def rocm_unquantized_gemm_impl(
and n <= 128
and k > 512
and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count()
and x.is_contiguous()
)
# k == 2880 and (m == 640 or m == 128))
)
@@ -179,6 +180,7 @@ def rocm_unquantized_gemm_impl(
and on_gfx9()
and x.dtype in [torch.float16, torch.bfloat16]
and k % 8 == 0
and x.is_contiguous()
)
if use_skinny is not True: