diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 50b6f6315..a6cf63f22 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -13,6 +13,13 @@ #include "dispatch_utils.h" #include "quantization/w8a8/fp8/common.cuh" +// TODO(rasmith): The kernels in this file are susceptible to integer overflow +// issues, do not take strides, and are unable to handle PyTorch tensors that +// return is_contiguous() as False (the tensors may actually be contiguous +// in memory). +// +// However, it may be possible to fix these kernels to handle both issues. + #if defined(__HIPCC__) && \ (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) #define __HIP__GFX9__ diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 687105f04..7cce82073 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2027,20 +2027,35 @@ def selective_scan_fwd( ) +# NOTE: The wvSplitK kernel (and all of the kernels in skinny_gemms.cu) +# are unable to properly handle non-contiguous +# tensors. It might be a good TODO(rasmith) to augment these kernels +# to be able to handle non-contiguous tensors for better performance. 
+def rocm_enforce_contiguous_skinny_gemm_inputs( + a: torch.Tensor, b: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + a = a.contiguous() # no-op if already contiguous, else clone + b = b.contiguous() # no-op if already contiguous, else clone + return a, b + + # ROCm skinny gemms def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor: + a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b) return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) def wvSplitK( a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None ) -> torch.Tensor: + a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b) return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count) def wvSplitKrc( a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None ) -> torch.Tensor: + a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b) return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count) @@ -2053,6 +2068,7 @@ def wvSplitKQ( cu_count: int, bias: torch.Tensor = None, ) -> torch.Tensor: + a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b) out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device) torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count) return out