[CI][AMD][BugFix] Update wvSplitK (and other skinny_gemm wrappers) to ensure tensors passed in are made contiguous for the kernel (#32831)

Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
rasmith
2026-01-23 15:35:48 -06:00
committed by GitHub
parent dfab5f3764
commit 6cc6d92be5
2 changed files with 23 additions and 0 deletions


@@ -2027,20 +2027,35 @@ def selective_scan_fwd(
)
# NOTE: The wvSplitK kernel (and all of the kernels in skinny_gemms.cu)
# cannot properly handle non-contiguous tensors. It might be a good
# TODO(rasmith) to augment these kernels to handle non-contiguous
# tensors for better performance.
def rocm_enforce_contiguous_skinny_gemm_inputs(
    a: torch.Tensor, b: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    a = a.contiguous()  # no-op if already contiguous, else clone
    b = b.contiguous()  # no-op if already contiguous, else clone
    return a, b

# ROCm skinny gemms
def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor:
    a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
    return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)

def wvSplitK(
    a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor:
    a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
    return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count)

def wvSplitKrc(
    a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor:
    a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
    return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count)
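
A minimal sketch (not part of the diff) of why the guard above is cheap in the common case: .contiguous() returns the tensor itself when the layout is already contiguous and only clones for non-contiguous views such as transposes or slices.

import torch

x = torch.randn(8, 16)
same = x.contiguous()                  # already contiguous: returns x itself
assert same.data_ptr() == x.data_ptr()

view = x.t()                           # transposed view is non-contiguous
assert not view.is_contiguous()
copy = view.contiguous()               # materializes a contiguous clone
assert copy.is_contiguous()
assert copy.data_ptr() != view.data_ptr()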
@@ -2053,6 +2068,7 @@ def wvSplitKQ(
    cu_count: int,
    bias: torch.Tensor = None,
) -> torch.Tensor:
    a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
    out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device)
    torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count)
    return out
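
For reference, a pure-PyTorch sketch of the shape contract suggested by the wvSplitKQ hunk: the output is allocated as (b.shape[0], a.shape[0]), which is consistent with out = b @ a.T for 2-D inputs. This semantics is an assumption inferred from the allocation alone, not something the commit states.

import torch

a = torch.randn(7, 64)    # weight-like operand
b = torch.randn(3, 64)    # skinny activation-like operand
out = torch.empty((b.shape[0], a.shape[0]), dtype=a.dtype, device=b.device)
torch.matmul(b, a.t(), out=out)   # assumed reference semantics, not the kernel
assert out.shape == (3, 7)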