[CI][AMD][BugFix] Update wvSplitK (and other skinny_gemm wrappers) to ensure tensors passed will be made contiguous for the kernel (#32831)
Signed-off-by: Randall Smith <ransmith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
@@ -13,6 +13,13 @@
|
|||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
#include "quantization/w8a8/fp8/common.cuh"
|
#include "quantization/w8a8/fp8/common.cuh"
|
||||||
|
|
||||||
|
// TODO(rasmith): The kernels in this file are susceptible to integer overflow
|
||||||
|
// issues, do not take strides, and are unable to handle PyTorch tensors that
|
||||||
|
// return is_contiguous() as False (the tensors may actually be contiguous
|
||||||
|
// in memory).
|
||||||
|
//
|
||||||
|
// However, it may be possible to fix these kernels to handle both issues.
|
||||||
|
|
||||||
#if defined(__HIPCC__) && \
|
#if defined(__HIPCC__) && \
|
||||||
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
|
(defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__))
|
||||||
#define __HIP__GFX9__
|
#define __HIP__GFX9__
|
||||||
|
|||||||
@@ -2027,20 +2027,35 @@ def selective_scan_fwd(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: The wvSplitK kernel (and all of the kernels in skinny_gemms.cu)
# are unable to properly handle non-contiguous tensors. It might be a
# good TODO(rasmith) to augment these kernels to handle non-contiguous
# tensors directly for better performance.
def rocm_enforce_contiguous_skinny_gemm_inputs(
    a: torch.Tensor, b: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return contiguous versions of ``a`` and ``b`` for the skinny-gemm kernels.

    ``Tensor.contiguous()`` is a no-op (returns the same tensor) when the
    input already reports ``is_contiguous()``; otherwise it clones the data
    into contiguous storage.
    """
    return a.contiguous(), b.contiguous()
|
||||||
|
|
||||||
|
|
||||||
# ROCm skinny gemms
def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor:
    """Invoke the custom LLMM1 skinny-gemm op.

    Inputs are first forced contiguous because the underlying kernel cannot
    handle non-contiguous tensors.
    """
    lhs, rhs = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
    return torch.ops._rocm_C.LLMM1(lhs, rhs, rows_per_block)
|
||||||
|
|
||||||
|
|
||||||
def wvSplitK(
    a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor:
    """Invoke the custom wvSplitK skinny-gemm op.

    Inputs are first forced contiguous because the underlying kernel cannot
    handle non-contiguous tensors.

    NOTE(review): ``bias`` defaults to ``None`` despite the ``torch.Tensor``
    annotation (implicit Optional); kept byte-identical for interface
    compatibility.
    """
    lhs, rhs = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
    return torch.ops._rocm_C.wvSplitK(lhs, rhs, bias, cu_count)
|
||||||
|
|
||||||
|
|
||||||
def wvSplitKrc(
    a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor:
    """Invoke the custom wvSplitKrc skinny-gemm op.

    Inputs are first forced contiguous because the underlying kernel cannot
    handle non-contiguous tensors.

    NOTE(review): ``bias`` defaults to ``None`` despite the ``torch.Tensor``
    annotation (implicit Optional); kept byte-identical for interface
    compatibility.
    """
    lhs, rhs = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
    return torch.ops._rocm_C.wvSplitKrc(lhs, rhs, bias, cu_count)
|
||||||
|
|
||||||
|
|
||||||
@@ -2053,6 +2068,7 @@ def wvSplitKQ(
|
|||||||
cu_count: int,
|
cu_count: int,
|
||||||
bias: torch.Tensor = None,
|
bias: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
|
||||||
out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device)
|
out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device)
|
||||||
torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count)
|
torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count)
|
||||||
return out
|
return out
|
||||||
|
|||||||
Reference in New Issue
Block a user