[ROCm] Enable wvSplitK skinny GEMM kernel for RDNA4/gfx1x decode (#34709)

Signed-off-by: L.B.R. <lbr@mmonad.com>
Co-authored-by: L.B.R. <lbr@mmonad.com>
This commit is contained in:
L.B.R.
2026-03-20 15:11:23 +00:00
committed by GitHub
parent 44eea10f68
commit 1779c09898
4 changed files with 365 additions and 99 deletions

View File

@@ -26,6 +26,16 @@
#define __HIP__GFX9__
#endif
#if defined(__HIPCC__) && \
(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1150__) || \
defined(__gfx1151__) || defined(__gfx1200__) || defined(__gfx1201__))
#define __HIP__GFX1X__
#endif
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
#define __HIP__GFX12__
#endif
#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
#define __HIP__MI3XX__
#endif
@@ -37,15 +47,31 @@
#endif
int get_lds_size() {
static bool is_cached = false;
static int result;
if (is_cached == false) {
auto dprops = at::cuda::getCurrentDeviceProperties();
std::string device_arch = dprops->gcnArchName;
size_t substring = device_arch.find("gfx95");
result = (substring == std::string::npos ? 64 * 1024 : 160 * 1024);
is_cached = true;
static const int result = [] {
const auto* dprops = at::cuda::getCurrentDeviceProperties();
const std::string device_arch = dprops->gcnArchName;
return device_arch.find("gfx95") == std::string::npos ? 64 * 1024
: 160 * 1024;
}();
return result;
}
bool on_gfx1x() {
  // Cached once per process: reports whether the current device is an
  // RDNA3 (gfx11*) or RDNA4 (gfx12*) GPU, decided from the gcnArchName
  // string of the active device's properties.
  static const bool result = [] {
    const std::string arch =
        at::cuda::getCurrentDeviceProperties()->gcnArchName;
    const bool is_gfx11 = arch.find("gfx11") != std::string::npos;
    const bool is_gfx12 = arch.find("gfx12") != std::string::npos;
    return is_gfx11 || is_gfx12;
  }();
  return result;
}
bool on_gfx12() {
  // Cached once per process: true when the current device's gcnArchName
  // identifies an RDNA4 (gfx12*) GPU.
  static const bool result = [] {
    const std::string arch =
        at::cuda::getCurrentDeviceProperties()->gcnArchName;
    return arch.find("gfx12") != std::string::npos;
  }();
  return result;
}
@@ -286,21 +312,35 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
return out_c;
}
#if defined(__HIP__GFX9__) && !defined(__HIP__GFX1X__)
#define DOT2C(V0, V2, V3) \
if constexpr (std::is_same_v<scalar_t, half>) { \
asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \
asm("v_dot2c_f32_f16 %0, %2, %3" \
: "=v"(V0) \
: "0"(V0), "v"(V2), "v"(V3)); \
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { \
float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \
__bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \
V0 += (s.x + s.y); \
}
#elif defined(__HIP__GFX1X__)
// gfx1x: v_dot2_f32_f16 (VOP3-P, dot10-insts, available on gfx11+gfx12)
#define DOT2C(V0, V2, V3) \
if constexpr (std::is_same_v<scalar_t, half>) { \
asm("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(V0) : "v"(V2), "v"(V3)); \
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { \
float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \
__bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \
V0 += (s.x + s.y); \
}
#endif
// To avoid LLVM silently upcasting to double: keep the comparison purely in
// 32-bit unsigned arithmetic instead of routing through the overloaded min().
__device__ inline unsigned int min__(uint32_t a, uint32_t b) {
  return (a < b) ? a : b;
}
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) || defined(__HIP__GFX1X__)
// This version targets cases where A[] fits LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
@@ -442,14 +482,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
1); // row_shr2
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
1); // row_shr1
#if defined(__HIP__GFX9__)
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
1); // ROW_BCAST15
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
1); // ROW_BCAST31
#else
sum[n][y] += __shfl_xor(sum[n][y], 16);
#endif
}
}
if (threadIdx.x == 63) {
if (threadIdx.x == (THRDS - 1)) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
@@ -469,6 +513,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
} else {
#ifdef __HIP__GFX9__
#pragma unroll
for (int n = 0; n < N; n++) {
#pragma unroll
@@ -498,7 +543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
if (threadIdx.x == (THRDS - 1)) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
@@ -513,11 +558,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
#endif // __HIP__GFX9__ (MFMA path)
}
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
#else
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap,
@@ -528,9 +574,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
#endif
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) || defined(__HIP__GFX1X__)
// This version targets cases where A[] marginally exceeds LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
@@ -657,14 +703,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
1); // row_shr2
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
1); // row_shr1
#if defined(__HIP__GFX9__)
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
1); // ROW_BCAST15
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
1); // ROW_BCAST31
#else
sum[n][y] += __shfl_xor(sum[n][y], 16);
#endif
}
}
if (threadIdx.x == 63) {
if (threadIdx.x == (THRDS - 1)) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
@@ -686,6 +736,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
} else {
#ifdef __HIP__GFX9__
#pragma unroll
for (int n = 0; n < N; n++) {
#pragma unroll
@@ -713,7 +764,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
if (threadIdx.x == (THRDS - 1)) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
@@ -730,6 +781,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
#endif // __HIP__GFX9__ (MFMA path)
}
m += CuCount * _WvPrGrp * YTILE;
@@ -746,7 +798,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
#else
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_(const int K, const int Kbp, const int Kap,
@@ -756,9 +808,9 @@ __global__ void wvSplitK_hf_(const int K, const int Kbp, const int Kap,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
#endif
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
#if defined(__HIP__GFX9__) || defined(__HIP__GFX1X__)
// This version targets big A[] cases, where it is much larger than LDS capacity
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
@@ -1004,14 +1056,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
1); // row_shr2
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
1); // row_shr1
#if defined(__HIP__GFX9__)
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
1); // ROW_BCAST15
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
1); // ROW_BCAST31
#else
sum[n][y] += __shfl_xor(sum[n][y], 16);
#endif
}
}
if (threadIdx.x == 63) {
if (threadIdx.x == (THRDS - 1)) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
@@ -1033,6 +1089,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
} else {
#ifdef __HIP__GFX9__
#pragma unroll
for (int n = 0; n < N; n++) {
#pragma unroll
@@ -1057,7 +1114,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
if (threadIdx.x == (THRDS - 1)) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
@@ -1074,6 +1131,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
#endif // __HIP__GFX9__ (MFMA path)
}
m += CuCount * _WvPrGrp * YTILE;
@@ -1090,7 +1148,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
#else
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_big_(const int K, const int Kbp, const int Kap,
@@ -1101,7 +1159,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int Kbp, const int Kap,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
#endif
// Find the min val of div2 that doesn't increase N/(div1*div2)
int mindiv(int N, int div1, int div2) {
@@ -1148,40 +1206,40 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size() / 2;
#define WVSPLITK(_YTILE, _UNRL, _N) \
#define WVSPLITK_CFG(_THRDS, _WVPRGRP, _YTILE, _UNRL, _N) \
{ \
dim3 block(64, 16); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
dim3 block(_THRDS, _WVPRGRP); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, _WVPRGRP); \
if ((Kbp_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
wvSplitK_hf_sml_<fptype, _THRDS, _YTILE, _WVPRGRP, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
else if (Kbp_in * N_in <= max_lds_len * 1.2) \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
wvSplitK_hf_<fptype, _THRDS, _YTILE, _WVPRGRP, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
else \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
wvSplitK_hf_big_<fptype, _THRDS, _YTILE, _WVPRGRP, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
}
#define WVSPLIT_TILE(_sYT, __N) \
#define WVSPLIT_TILE_CFG(_THRDS, _WVPRGRP, _sYT, __N) \
{ \
bool fit_lds = (Kbp_in * N_in <= max_lds_len); \
if (_sYT <= 1) \
WVSPLITK(1, 4, __N) \
WVSPLITK_CFG(_THRDS, _WVPRGRP, 1, 4, __N) \
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
WVSPLITK(2, 2, __N) \
WVSPLITK_CFG(_THRDS, _WVPRGRP, 2, 2, __N) \
else if (_sYT <= 4 * 3) \
WVSPLITK(3, 2, __N) \
WVSPLITK_CFG(_THRDS, _WVPRGRP, 3, 2, __N) \
else if (__N == 4) \
WVSPLITK(4, 1, __N) \
WVSPLITK_CFG(_THRDS, _WVPRGRP, 4, 1, __N) \
else \
WVSPLITK(4, 2, __N) \
WVSPLITK_CFG(_THRDS, _WVPRGRP, 4, 2, __N) \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
@@ -1198,18 +1256,31 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
// then cut the active waves to balance their distribution...
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
const bool use_wave32 = on_gfx1x();
switch (N_in) {
case 1:
WVSPLIT_TILE(sYT, 1)
if (use_wave32)
WVSPLIT_TILE_CFG(32, 16, sYT, 1)
else
WVSPLIT_TILE_CFG(64, 16, sYT, 1)
break;
case 2:
WVSPLIT_TILE(sYT, 2)
if (use_wave32)
WVSPLIT_TILE_CFG(32, 16, sYT, 2)
else
WVSPLIT_TILE_CFG(64, 16, sYT, 2)
break;
case 3:
WVSPLIT_TILE(sYT, 3)
if (use_wave32)
WVSPLIT_TILE_CFG(32, 16, sYT, 3)
else
WVSPLIT_TILE_CFG(64, 16, sYT, 3)
break;
case 4:
WVSPLIT_TILE(sYT, 4)
if (use_wave32)
WVSPLIT_TILE_CFG(32, 16, sYT, 4)
else
WVSPLIT_TILE_CFG(64, 16, sYT, 4)
break;
default:
throw std::runtime_error(
@@ -1653,7 +1724,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#endif
}
}
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
#else
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC>
__global__ void wvSplitKrc_(const int actlN, const int K, const int Kap,
@@ -1688,6 +1759,8 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
TORCH_CHECK(in_a.dtype() == torch::kFloat16 ||
in_a.dtype() == torch::kBFloat16);
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
auto out_c = torch::empty(
{N_in, M_in},
torch::TensorOptions().dtype(in_a.dtype()).device(in_a.device()));
@@ -1696,7 +1769,6 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
dim3 grid(CuCount);
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// const int max_lds_len = get_lds_size() / 2;
@@ -1773,7 +1845,7 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
return out_c;
}
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
@@ -1817,12 +1889,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
float sA = *s_A;
float sB = *s_B;
while (m < M) {
#ifdef __HIP__GFX12__
// gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8
float sum[N][YTILE] = {};
#else
// gfx9: MFMA accumulation
scalar8 sum[N][YTILE] = {};
#endif
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
@@ -1854,6 +1931,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
for (uint32_t n = 0; n < N; n++) {
#ifdef __HIP__GFX12__
// gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4)
for (int y = 0; y < YTILE; ++y) {
#pragma unroll
for (int i = 0; i < A_CHUNK / 4; i++) {
sum[n][y] = __builtin_amdgcn_dot4_f32_fp8_fp8(
bigA[n][k2].i[i], bigB[y][k2].i[i], sum[n][y]);
}
}
#else
// gfx9: MFMA path
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
@@ -1861,11 +1949,33 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
0);
}
}
#endif
}
}
}
// Final reduction
#ifdef __HIP__GFX12__
// gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:1 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
sum[n][y] += __shfl_xor(sum[n][y], 16);
}
}
#else
// gfx9 MFMA reduction
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
@@ -1880,8 +1990,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
sum[n][y][0] = accm0;
}
}
#endif
if (threadIdx.x == 0) {
const bool writeback_lane =
#ifdef __HIP__GFX12__
threadIdx.x == (THRDS - 1);
#else
threadIdx.x == 0;
#endif
if (writeback_lane) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
@@ -1892,13 +2009,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault.
sum[n][y][0] *= sA * sB;
#ifdef __HIP__GFX12__
float result = sum[n][y] * sA * sB;
#else
float result = sum[n][y][0] * sA * sB;
#endif
if constexpr (std::is_same_v<scalar_t, half>) {
sum[n][y][0] += __half2float(biases[n][y]);
result += __half2float(biases[n][y]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
sum[n][y][0] += __bfloat162float(biases[n][y]);
result += __bfloat162float(biases[n][y]);
}
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
C[m + y + n * M] = __float2s<scalar_t>(result);
}
}
}
@@ -1906,7 +2027,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
@@ -1918,9 +2039,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
@@ -1963,12 +2084,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
float sA = *s_A;
float sB = *s_B;
while (m < M) {
#ifdef __HIP__GFX12__
// gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8
float sum[N][YTILE] = {};
#else
// gfx9: MFMA accumulation
scalar8 sum[N][YTILE] = {};
#endif
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
@@ -2002,6 +2128,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
for (uint32_t n = 0; n < N; n++) {
#ifdef __HIP__GFX12__
// gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4)
for (int y = 0; y < YTILE; ++y) {
#pragma unroll
for (int i = 0; i < A_CHUNK / 4; i++) {
sum[n][y] = __builtin_amdgcn_dot4_f32_fp8_fp8(
bigA[n][k2].i[i], bigB[y][k2].i[i], sum[n][y]);
}
}
#else
// gfx9: MFMA path
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
@@ -2009,11 +2146,33 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
0);
}
}
#endif
}
}
}
// Final reduction
#ifdef __HIP__GFX12__
// gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:1 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
sum[n][y] += __shfl_xor(sum[n][y], 16);
}
}
#else
// gfx9 MFMA reduction
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
@@ -2028,8 +2187,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
sum[n][y][0] = accm0;
}
}
#endif
if (threadIdx.x == 0) {
const bool writeback_lane =
#ifdef __HIP__GFX12__
threadIdx.x == (THRDS - 1);
#else
threadIdx.x == 0;
#endif
if (writeback_lane) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
@@ -2040,13 +2206,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault.
sum[n][y][0] *= sA * sB;
#ifdef __HIP__GFX12__
float result = sum[n][y] * sA * sB;
#else
float result = sum[n][y][0] * sA * sB;
#endif
if constexpr (std::is_same_v<scalar_t, half>) {
sum[n][y][0] += __half2float(biases[n][y]);
result += __half2float(biases[n][y]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
sum[n][y][0] += __bfloat162float(biases[n][y]);
result += __bfloat162float(biases[n][y]);
}
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
C[m + y + n * M] = __float2s<scalar_t>(result);
}
}
}
@@ -2054,7 +2224,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m += CuCount * _WvPrGrp * YTILE;
}
}
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__)
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
@@ -2066,7 +2236,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
const int CuCount) {
UNREACHABLE_CODE
}
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
@@ -2099,24 +2269,30 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size();
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
#define WVSPLITKQ_IMPL(_THRDS, _WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
{ \
dim3 block(64, _WvPrGrp); \
dim3 block(_THRDS, _WvPrGrp); \
if ((Kap_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEs, 16)); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
wvSplitKQ_hf_sml_<fptype, fp8_t, _THRDS, _YTILEs, _WvPrGrp, 16, _UNRLs, \
_N><<<grid, block, 0, stream>>>( \
K_in, Kap_in, Kbp_in, M_in, Bx_in, By_in, b_ptr, a_ptr, bias_ptr, \
c_ptr, s_a, s_b, __wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEm, 16)); \
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
wvSplitKQ_hf_<fptype, fp8_t, _THRDS, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
} \
}
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
if (on_gfx12()) \
WVSPLITKQ_IMPL(32, _WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
else \
WVSPLITKQ_IMPL(64, _WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N)
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] {
using fptype = typename scalar<scalar_t>::type;
auto c_ptr = reinterpret_cast<fptype*>(out_c.data_ptr());
@@ -2136,10 +2312,10 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
WVSPLITKQ(16, 2, 2, 2, 2, 2)
break;
case 3:
WVSPLITKQ(16, 2, 2, 2, 2, 3)
WVSPLITKQ(16, 2, 2, 1, 1, 3)
break;
case 4:
WVSPLITKQ(16, 2, 2, 2, 2, 4)
WVSPLITKQ(16, 2, 2, 1, 1, 4)
break;
default:
throw std::runtime_error(

View File

@@ -160,6 +160,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, padded_a, bias_mode
BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
elif bias_mode == 2:
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
elif bias_mode == 3:
BIAS = torch.rand(1, m, dtype=dtype, device="cuda") * 2 - 1
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitKrc(A, B, cu_count, BIAS)
@@ -224,10 +226,9 @@ def test_rocm_wvsplitk_kernel(
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
if xnorm:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
else:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2)
# Accumulation error in fp16 GEMM scales with sqrt(K)
atol = torch.finfo(dtype).eps * math.sqrt(k)
torch.testing.assert_close(out, ref_out, atol=atol, rtol=1e-2)
@pytest.mark.parametrize("xnorm", [False, True])

View File

@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
import torch
from vllm.platforms import current_platform
if current_platform.is_cuda():
pytest.skip(
"ROCm skinny GEMM tests are not supported on CUDA.",
allow_module_level=True,
)
from vllm.model_executor.layers import utils
def test_rocm_unquantized_gemm_gfx1x_wvsplitk_path(monkeypatch):
x = torch.randn(1, 64, dtype=torch.float16)
weight = torch.randn(128, 64, dtype=torch.float16)
monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: True)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: False)
monkeypatch.setattr(utils, "get_cu_count", lambda: 120)
wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)
llmm1_mock = MagicMock(side_effect=lambda w, x_view, _: x_view @ w.t())
monkeypatch.setattr(utils.ops, "LLMM1", llmm1_mock)
out = utils.rocm_unquantized_gemm_impl(x, weight, None)
ref = torch.nn.functional.linear(x, weight, None)
wvsplitk_mock.assert_called_once()
llmm1_mock.assert_not_called()
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
def test_rocm_unquantized_gemm_gfx1x_n_gt_4_falls_back(monkeypatch):
x = torch.randn(5, 64, dtype=torch.float16)
weight = torch.randn(128, 64, dtype=torch.float16)
monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: True)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: False)
monkeypatch.setattr(utils, "get_cu_count", lambda: 120)
wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)
llmm1_mock = MagicMock(side_effect=lambda w, x_view, _: x_view @ w.t())
monkeypatch.setattr(utils.ops, "LLMM1", llmm1_mock)
out = utils.rocm_unquantized_gemm_impl(x, weight, None)
ref = torch.nn.functional.linear(x, weight, None)
wvsplitk_mock.assert_not_called()
llmm1_mock.assert_not_called()
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
def test_rocm_unquantized_gemm_gfx950_wvsplitkrc_path(monkeypatch):
x = torch.randn(16, 1024, dtype=torch.float16)
weight = torch.randn(256, 1024, dtype=torch.float16)
monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: False)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: True)
monkeypatch.setattr(utils, "get_cu_count", lambda: 120)
wvsplitkrc_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
monkeypatch.setattr(utils.ops, "wvSplitKrc", wvsplitkrc_mock)
wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)
out = utils.rocm_unquantized_gemm_impl(x, weight, None)
ref = torch.nn.functional.linear(x, weight, None)
wvsplitkrc_mock.assert_called_once()
wvsplitk_mock.assert_not_called()
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)

View File

@@ -122,7 +122,7 @@ def use_aiter_triton_gemm(n, m, k, dtype):
def rocm_unquantized_gemm_impl(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
from vllm.platforms.rocm import on_gfx9, on_gfx950
from vllm.platforms.rocm import on_gfx1x, on_gfx9, on_gfx950
n = x.numel() // x.size(-1)
m = weight.shape[0]
@@ -169,12 +169,12 @@ def rocm_unquantized_gemm_impl(
use_skinny = (
envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx9()
and (on_gfx9() or on_gfx1x())
and x.dtype in [torch.float16, torch.bfloat16]
and k % 8 == 0
)
if use_skinny is not True:
if not use_skinny:
return torch.nn.functional.linear(x, weight, bias)
x_view = x.reshape(-1, x.size(-1))