Improvements to wvSplitKrc skinny GEMM solution (#34304)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Author: Hashem Hashemi
Date: 2026-03-10 09:14:27 -07:00
Committed by: GitHub
Parent: aefc59f088
Commit: 721ae79f50
3 changed files with 169 additions and 98 deletions

@@ -12,6 +12,7 @@
#include "../cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/w8a8/fp8/common.cuh"
#include "core/batch_invariant.hpp"
// TODO(rasmith): The kernels in this file are susceptible to integer overflow
// issues, do not take strides, and are unable to handle PyTorch tensors that
@@ -1224,17 +1225,14 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
#if defined(__gfx950__)
#define WVSPLITKRC_1KPASS
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N, int GrpsShrB, int CHUNKK>
int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
__attribute__((amdgpu_waves_per_eu(1, 1)))
wvSplitKrc_(const int actlN, const int K, const int M, const int Bx,
const int By, const scalar_t* __restrict__ B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, float* glbl, scalar_t* C,
const int CuCount) {
// Use upper half of glbl buffer for atomic reduce counting
int* cntr = (int*)(&glbl[M * N]);
wvSplitKrc_(const int actlN, const int K, const int Kap, const int M,
const int Bx, const int By, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
const scalar_t* __restrict__ BIAS, float* glbl, int* cntr,
scalar_t* C, const int CuCount) {
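// DTRMNSTC (read here as "deterministic") selects the batch-invariant
// reduction path: each K-shard's partial sums are staged into their own slice
// of glbl and reduced in a fixed order, instead of being combined with
// atomicAdd. cntr counts how many K-shards have finished each output tile.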
constexpr int NTILE = 16;
constexpr int APAD = 1;
constexpr int ASTRD = 64;
@@ -1425,11 +1423,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff);
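// Stage the next K-chunk of A straight from global memory into the LDS
// buffer s; global_load_lds bypasses registers, and Kap (the allocated row
// stride of A) is used for addressing so A need not be contiguous in K.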
for (unsigned int n = 0; n < N; n += CHUNKK * sprdN) {
__builtin_amdgcn_global_load_lds(
(int*)(&A[min__(
K * actlN - A_CHUNK,
kOffcp + K * (n / CHUNKK +
(N / CHUNKK) * (threadIdx.x / (64 / CHUNKK)) +
(threadIdx.y % sprdN)))]),
(int*)(&A[min__(Kap * actlN - A_CHUNK,
kOffcp + Kap * (n / CHUNKK +
(N / CHUNKK) * (threadIdx.x /
(64 / CHUNKK)) +
(threadIdx.y % sprdN)))]),
(int*)(&s[(k +
kFitPdd * ((n / CHUNKK) + (threadIdx.y % sprdN)))]),
16, 0, 0);
@@ -1533,45 +1531,98 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
union flt4 {
scalar8 s8;
float2 f2[2];
float4 f4;
};
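// flt4 aliases the accumulator vector (s8) with float2 and float4 views, so
// the same partial sums can be written with 8- or 16-byte vector stores
// without manual repacking.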
if (m + (threadIdx.x % 16) < M) {
int my_cntr;
int mindx = m + (threadIdx.x % 16);
int g_mindx = m * 4 + (threadIdx.x % 64); // coalesced atomic reduction
scalar_t biases[N / NTILE / GrpsShrB][4] = {};
// Atomic add the output, read biases
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
for (uint32_t j = 0; j < 4; j++) {
// int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
// (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
// int adr = mindx + M * nindx;
int g_nindx =
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx + M * g_nindx * 4;
atomicAdd(&glbl[g_adr], sum4[nt][0][j]);
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
int g_nindx =
(nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4;
if (DTRMNSTC) {
flt4 flt4_ = {.s8 = sum4[nt][0]};
__hip_atomic_store((float2*)&glbl[g_adr + M * N * (m0 / Mmod)],
flt4_.f2[0], __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_AGENT);
__hip_atomic_store((float2*)&glbl[g_adr + 2 + M * N * (m0 / Mmod)],
flt4_.f2[1], __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_AGENT);
} else {
for (uint32_t j = 0; j < 4; j++)
atomicAdd((&glbl[g_adr + j]), sum4[nt][0][j]);
}
}
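// Deterministic path: each K-shard writes its partial tile into its own
// M * N slice of glbl (slice index m0 / Mmod) with plain relaxed stores.
// Default path: all shards atomicAdd into one shared slice, which needs no
// read-back pass but makes the summation order depend on wave scheduling.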
__atomic_signal_fence(__ATOMIC_SEQ_CST);
asm volatile("s_waitcnt vmcnt(0)" ::: "memory");
__atomic_signal_fence(__ATOMIC_SEQ_CST);
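// The fences plus s_waitcnt vmcnt(0) ensure the partial-sum stores above have
// completed before this wave advertises completion through cntr.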
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
int adr_ = mindx + M * nindx_ / 4;
// Update the complete counter
my_cntr = atomicAdd(&cntr[adr_], 1);
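// cntr[adr_] counts K-shards that have deposited partials for this output
// tile; the wave that observes my_cntr + 1 == k_rnd arrived last and owns the
// final reduction and the write to C.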
float vals[N / NTILE / GrpsShrB][4] = {};
// make sure LDS is free for write out staging
if (DTRMNSTC) __syncthreads();
// Update the complete counter
flt4 vals[N / NTILE / GrpsShrB] = {};
// If we're the last k-shard, read back the value and convert...
if (my_cntr + 1 == k_rnd) {
if (BIAS)
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
cntr[adr_] = 0; // clear for next round
if constexpr (DTRMNSTC) {
#pragma unroll
for (int ks = 0; ks < k_rnd; ks++) {
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
int g_nindx =
(nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4;
__builtin_amdgcn_global_load_lds(
(float4*)(&glbl[g_adr + M * N * ks]),
&(((float4*)s)[(threadIdx.y * THRDS) + ks * THRDS * 4 +
nt * THRDS * 4 * k_rnd]),
16, 0, 0);
}
}
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int g_nindx =
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx + M * g_nindx * 4;
vals[nt][j] = glbl[g_adr];
if (BIAS)
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
}
}
asm volatile("s_waitcnt 0");
for (int ks = 0; ks < k_rnd; ks++) {
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
float4 eval = ((float4*)s)[(threadIdx.x + threadIdx.y * THRDS) +
ks * THRDS * 4 + nt * THRDS * 4 * k_rnd];
vals[nt].f4 += eval;
}
}
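// Summing the per-shard partials in a fixed ks order (0 .. k_rnd - 1) makes
// the floating-point reduction order independent of which wave finished last,
// which is what makes this path batch-invariant.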
} else {
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
int g_nindx =
(nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4;
vals[nt].f4 = *(float4*)(&glbl[g_adr]);
*(float4*)(&glbl[g_adr]) = {}; // clear out for next round
}
if (BIAS)
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
for (uint32_t j = 0; j < 4; j++) {
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
}
}
}
__builtin_amdgcn_sched_barrier(0);
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
@@ -1581,11 +1632,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (nindx < actlN) {
int adr = mindx + M * nindx;
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
vals[nt][j] += __bfloat162float(biases[nt][j]);
C[adr] = __float2bfloat16(vals[nt][j]);
vals[nt].s8[j] += __bfloat162float(biases[nt][j]);
C[adr] = __float2bfloat16(vals[nt].s8[j]);
} else {
vals[nt][j] += __half2float(biases[nt][j]);
C[adr] = __float2half(vals[nt][j]);
vals[nt].s8[j] += __half2float(biases[nt][j]);
C[adr] = __float2half(vals[nt].s8[j]);
}
}
}
@@ -1604,21 +1655,25 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N, int GrpsShrB, int CHUNKK>
__global__ void wvSplitKrc_(const int actlN, const int K, const int M,
const int Bx, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC>
__global__ void wvSplitKrc_(const int actlN, const int K, const int Kap,
const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, float* glbl,
// int* cntr,
scalar_t* C, const int CuCount){UNREACHABLE_CODE}
int* cntr, scalar_t* C,
const int CuCount){UNREACHABLE_CODE}
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
const std::optional<at::Tensor>& in_bias,
const int64_t CuCount) {
auto M_in = in_a.size(0);
auto N_in = in_b.size(0);
auto K_in = in_a.size(1);
int _DTRMNSTC = 1; // vllm::vllm_is_batch_invariant();
auto M_in = in_b.size(0);
auto N_in = in_a.size(0);
auto K_in = in_b.size(1);
auto Kap_in = in_a.stride(0);
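// Kap_in is the row stride of in_a in elements; it can exceed the logical K
// when in_a is a padded or sliced view, and the kernel addresses A with Kap
// so such views are handled without a copy.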
auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
@@ -1635,13 +1690,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
auto out_c = torch::empty(
{N_in, M_in},
torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device()));
torch::TensorOptions().dtype(in_a.dtype()).device(in_a.device()));
auto N_p2 = 1U << (32 - __builtin_clz(N_in - 1));
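// N_p2 is N_in rounded up to the next power of two (e.g. N_in = 48 -> 64,
// N_in = 64 stays 64); the switch below dispatches on it.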
auto axl_glbl = torch::empty(
{N_p2 + N_p2 / 4, M_in + M_in / 4},
torch::TensorOptions().dtype(torch::kFloat32).device(in_b.device()));
axl_glbl.zero_(); // disable for FAST_UNSAFE_RDC_INIT
dim3 grid(CuCount);
@@ -1649,55 +1700,70 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// const int max_lds_len = get_lds_size() / 2;
// With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
// and each working on a 512-shard of K, how many CUs would we need?
int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512);
// How many of the 4 waves in a group can work on the same 16 Ms at the same
// time? First try to maximize this; it reduces the Ms each group works on,
// i.e. it increases the number of CUs needed.
int GrpsShrB = min(N_p2 / 16, 4);
// Given the above, how many CUs would we need?
int CuNeeded = rndup_cus * GrpsShrB;
if (CuNeeded > CuCount) throw std::runtime_error("Invalid wvSplitKrc size");
// Can we increase SplitK by shrinking the K-shard to 256?
int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1;
static torch::Tensor axl_glbl =
torch::zeros(
128 * 1024 * (_DTRMNSTC ? 12 : 1),
torch::TensorOptions().dtype(torch::kFloat32).device(in_a.device()))
.detach();
static torch::Tensor axl_cntr =
torch::zeros(
128 * 1024 * (_DTRMNSTC ? 12 : 1) / 4,
torch::TensorOptions().dtype(torch::kInt).device(in_a.device()))
.detach();
auto glbl = axl_glbl.data_ptr<float>();
auto cntr = axl_cntr.data_ptr<int>();
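// The reduction scratch and counters are allocated once as static tensors and
// reused across calls: 128K floats (x12 slices in deterministic mode so each
// K-shard keeps its partials separate). The capacity is a fixed sizing choice
// of this kernel, not derived from the problem shape here.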
#define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK) \
{ \
dim3 block(64, 4); \
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK> \
<<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, glbl, c, CuCount); \
if (_DTRMNSTC) \
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK, 1> \
<<<grid, block, 0, stream>>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \
af4, bf4, biasf4, glbl, cntr, c, \
CuCount); \
else \
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK, 0> \
<<<grid, block, 0, stream>>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \
af4, bf4, biasf4, glbl, cntr, c, \
CuCount); \
}
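// DTRMNSTC is a compile-time template argument, so each instantiation carries
// only one of the two reduction paths; the runtime _DTRMNSTC flag merely picks
// which instantiation to launch.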
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitKrc", [&] {
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_a.scalar_type(), "wvSplitKrc", [&] {
using fptype = typename scalar<scalar_t>::type;
fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
const fptype* af4 = reinterpret_cast<const fptype*>(in_a.data_ptr());
const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
const fptype* biasf4 =
(in_bias.has_value() && in_bias->numel() > 0)
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
: nullptr;
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
auto glbl = axl_glbl.data_ptr<float>();
// With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
// and each working on a 512-shard of K, how many CUs would we need?
int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512);
// How many of the 4 waves in a group can work on the same 16 Ms at the same
// time? First try to maximize this; it reduces the Ms each group works on,
// i.e. it increases the number of CUs needed.
int GrpsShrB = min(N_p2 / 16, 4);
// Given the above, how many CUs would we need?
int CuNeeded = rndup_cus * GrpsShrB;
if (CuNeeded > CuCount) throw std::runtime_error("Invalid wvSplitKrc size");
// Can we increase SplitK by shrinking the K-shard to 256?
int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1;
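// Example with hypothetical sizes: M_in = 64, K_in = 2048, N_p2 = 32 on a
// 304-CU device: rndup_cus = 1 * 4 = 4, GrpsShrB = min(32 / 16, 4) = 2,
// CuNeeded = 8; since 8 * 2 <= 304, chunkk = 2 and each CU works a 256-wide
// K shard.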
switch (N_p2) {
case 16:
WVSPLITKrc(16, 1, 1) break;
case 32:
if (chunkk == 2)
WVSPLITKrc(32, 2, 2) else if (chunkk == 1) WVSPLITKrc(32, 2, 1) break;
if (chunkk == 2) WVSPLITKrc(32, 2, 2) else WVSPLITKrc(32, 2, 1) break;
case 64:
if (chunkk == 2)
WVSPLITKrc(64, 4, 2) else if (chunkk == 1) WVSPLITKrc(64, 4, 1) break;
if (chunkk == 2) WVSPLITKrc(64, 4, 2) else WVSPLITKrc(64, 4, 1) break;
case 128:
if (chunkk == 2)
WVSPLITKrc(128, 4, 2) else if (chunkk == 1)
WVSPLITKrc(128, 4, 1) break;
if (chunkk == 2) WVSPLITKrc(128, 4, 2) else WVSPLITKrc(128, 4, 1) break;
default:
throw std::runtime_error(
"Unsupported N value: " + std::to_string(M_in) + "," +