Improvements to wvSplitKrc skinny GEMM solution (#34304)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include "../cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#include "quantization/w8a8/fp8/common.cuh"
|
||||
#include "core/batch_invariant.hpp"
|
||||
|
||||
// TODO(rasmith): The kernels in this file are susceptible to integer overflow
|
||||
// issues, do not take strides, and are unable to handle PyTorch tensors that
|
||||
@@ -1224,17 +1225,14 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
#if defined(__gfx950__)
|
||||
#define WVSPLITKRC_1KPASS
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N, int GrpsShrB, int CHUNKK>
|
||||
int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
__attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
wvSplitKrc_(const int actlN, const int K, const int M, const int Bx,
|
||||
const int By, const scalar_t* __restrict__ B,
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, float* glbl, scalar_t* C,
|
||||
const int CuCount) {
|
||||
// Use upper half of glbl buffer for atomic reduce counting
|
||||
int* cntr = (int*)(&glbl[M * N]);
|
||||
|
||||
wvSplitKrc_(const int actlN, const int K, const int Kap, const int M,
|
||||
const int Bx, const int By, const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
const scalar_t* __restrict__ BIAS, float* glbl, int* cntr,
|
||||
scalar_t* C, const int CuCount) {
|
||||
constexpr int NTILE = 16;
|
||||
constexpr int APAD = 1;
|
||||
constexpr int ASTRD = 64;
|
||||
@@ -1425,11 +1423,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff);
|
||||
for (unsigned int n = 0; n < N; n += CHUNKK * sprdN) {
|
||||
__builtin_amdgcn_global_load_lds(
|
||||
(int*)(&A[min__(
|
||||
K * actlN - A_CHUNK,
|
||||
kOffcp + K * (n / CHUNKK +
|
||||
(N / CHUNKK) * (threadIdx.x / (64 / CHUNKK)) +
|
||||
(threadIdx.y % sprdN)))]),
|
||||
(int*)(&A[min__(Kap * actlN - A_CHUNK,
|
||||
kOffcp + Kap * (n / CHUNKK +
|
||||
(N / CHUNKK) * (threadIdx.x /
|
||||
(64 / CHUNKK)) +
|
||||
(threadIdx.y % sprdN)))]),
|
||||
(int*)(&s[(k +
|
||||
kFitPdd * ((n / CHUNKK) + (threadIdx.y % sprdN)))]),
|
||||
16, 0, 0);
|
||||
@@ -1533,45 +1531,98 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
}
|
||||
}
|
||||
|
||||
union flt4 {
|
||||
scalar8 s8;
|
||||
float2 f2[2];
|
||||
float4 f4;
|
||||
};
|
||||
if (m + (threadIdx.x % 16) < M) {
|
||||
int my_cntr;
|
||||
int mindx = m + (threadIdx.x % 16);
|
||||
int g_mindx = m * 4 + (threadIdx.x % 64); // coalesced atomic reduction
|
||||
scalar_t biases[N / NTILE / GrpsShrB][4] = {};
|
||||
// Atomic add the output, read biases
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
// int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
// (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
// int adr = mindx + M * nindx;
|
||||
int g_nindx =
|
||||
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
|
||||
int g_adr = g_mindx + M * g_nindx * 4;
|
||||
atomicAdd(&glbl[g_adr], sum4[nt][0][j]);
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
int g_nindx =
|
||||
(nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
|
||||
int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4;
|
||||
if (DTRMNSTC) {
|
||||
flt4 flt4_ = {.s8 = sum4[nt][0]};
|
||||
__hip_atomic_store((float2*)&glbl[g_adr + M * N * (m0 / Mmod)],
|
||||
flt4_.f2[0], __ATOMIC_RELAXED,
|
||||
__HIP_MEMORY_SCOPE_AGENT);
|
||||
__hip_atomic_store((float2*)&glbl[g_adr + 2 + M * N * (m0 / Mmod)],
|
||||
flt4_.f2[1], __ATOMIC_RELAXED,
|
||||
__HIP_MEMORY_SCOPE_AGENT);
|
||||
} else {
|
||||
for (uint32_t j = 0; j < 4; j++)
|
||||
atomicAdd((&glbl[g_adr + j]), sum4[nt][0][j]);
|
||||
}
|
||||
}
|
||||
|
||||
__atomic_signal_fence(__ATOMIC_SEQ_CST);
|
||||
asm volatile("s_waitcnt vmcnt(0)" ::: "memory");
|
||||
__atomic_signal_fence(__ATOMIC_SEQ_CST);
|
||||
|
||||
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr_ = mindx + M * nindx_ / 4;
|
||||
// Update the complete counter
|
||||
my_cntr = atomicAdd(&cntr[adr_], 1);
|
||||
float vals[N / NTILE / GrpsShrB][4] = {};
|
||||
|
||||
// make sure LDS is free for write out staging
|
||||
if (DTRMNSTC) __syncthreads();
|
||||
|
||||
// Update the complete counter
|
||||
flt4 vals[N / NTILE / GrpsShrB] = {};
|
||||
// If we're the last k-shard, read back the value and convert...
|
||||
if (my_cntr + 1 == k_rnd) {
|
||||
if (BIAS)
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
|
||||
cntr[adr_] = 0; // clear for next round
|
||||
if constexpr (DTRMNSTC) {
|
||||
#pragma unroll
|
||||
for (int ks = 0; ks < k_rnd; ks++) {
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
int g_nindx =
|
||||
(nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
|
||||
int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4;
|
||||
__builtin_amdgcn_global_load_lds(
|
||||
(float4*)(&glbl[g_adr + M * N * ks]),
|
||||
&(((float4*)s)[(threadIdx.y * THRDS) + ks * THRDS * 4 +
|
||||
nt * THRDS * 4 * k_rnd]),
|
||||
16, 0, 0);
|
||||
}
|
||||
}
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int g_nindx =
|
||||
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
|
||||
int g_adr = g_mindx + M * g_nindx * 4;
|
||||
vals[nt][j] = glbl[g_adr];
|
||||
if (BIAS)
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
|
||||
}
|
||||
}
|
||||
asm volatile("s_waitcnt 0");
|
||||
for (int ks = 0; ks < k_rnd; ks++) {
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
float4 eval = ((float4*)s)[(threadIdx.x + threadIdx.y * THRDS) +
|
||||
ks * THRDS * 4 + nt * THRDS * 4 * k_rnd];
|
||||
vals[nt].f4 += eval;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
int g_nindx =
|
||||
(nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
|
||||
int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4;
|
||||
vals[nt].f4 = *(float4*)(&glbl[g_adr]);
|
||||
*(float4*)(&glbl[g_adr]) = {}; // clear out for next round
|
||||
}
|
||||
if (BIAS)
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
|
||||
}
|
||||
}
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
@@ -1581,11 +1632,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (nindx < actlN) {
|
||||
int adr = mindx + M * nindx;
|
||||
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
vals[nt][j] += __bfloat162float(biases[nt][j]);
|
||||
C[adr] = __float2bfloat16(vals[nt][j]);
|
||||
vals[nt].s8[j] += __bfloat162float(biases[nt][j]);
|
||||
C[adr] = __float2bfloat16(vals[nt].s8[j]);
|
||||
} else {
|
||||
vals[nt][j] += __half2float(biases[nt][j]);
|
||||
C[adr] = __float2half(vals[nt][j]);
|
||||
vals[nt].s8[j] += __half2float(biases[nt][j]);
|
||||
C[adr] = __float2half(vals[nt].s8[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1604,21 +1655,25 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
}
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N, int GrpsShrB, int CHUNKK>
|
||||
__global__ void wvSplitKrc_(const int actlN, const int K, const int M,
|
||||
const int Bx, const int By, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A,
|
||||
int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC>
|
||||
__global__ void wvSplitKrc_(const int actlN, const int K, const int Kap,
|
||||
const int M, const int Bx, const int By,
|
||||
const scalar_t* B, const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, float* glbl,
|
||||
// int* cntr,
|
||||
scalar_t* C, const int CuCount){UNREACHABLE_CODE}
|
||||
int* cntr, scalar_t* C,
|
||||
const int CuCount){UNREACHABLE_CODE}
|
||||
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
|
||||
torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const std::optional<at::Tensor>& in_bias,
|
||||
const int64_t CuCount) {
|
||||
auto M_in = in_a.size(0);
|
||||
auto N_in = in_b.size(0);
|
||||
auto K_in = in_a.size(1);
|
||||
int _DTRMNSTC = 1; // vllm::vllm_is_batch_invariant();
|
||||
|
||||
auto M_in = in_b.size(0);
|
||||
auto N_in = in_a.size(0);
|
||||
auto K_in = in_b.size(1);
|
||||
auto Kap_in = in_a.stride(0);
|
||||
|
||||
auto Bx_in =
|
||||
(in_bias.has_value() && in_bias->numel() > 0)
|
||||
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
|
||||
@@ -1635,13 +1690,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
|
||||
auto out_c = torch::empty(
|
||||
{N_in, M_in},
|
||||
torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device()));
|
||||
torch::TensorOptions().dtype(in_a.dtype()).device(in_a.device()));
|
||||
|
||||
auto N_p2 = 1U << (32 - __builtin_clz(N_in - 1));
|
||||
auto axl_glbl = torch::empty(
|
||||
{N_p2 + N_p2 / 4, M_in + M_in / 4},
|
||||
torch::TensorOptions().dtype(torch::kFloat32).device(in_b.device()));
|
||||
axl_glbl.zero_(); // disable for FAST_UNSAFE_RDC_INIT
|
||||
|
||||
dim3 grid(CuCount);
|
||||
|
||||
@@ -1649,55 +1700,70 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
// const int max_lds_len = get_lds_size() / 2;
|
||||
|
||||
// With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
|
||||
// and each working on a 512-shard of K, how many CUs would we need?
|
||||
int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512);
|
||||
|
||||
// How many of 4 waves in a group can work on same 16 Ms at same time? First
|
||||
// try to maximize this. This reduces the Ms each group works on, i.e.
|
||||
// increasing the number of CUs needed.
|
||||
int GrpsShrB = min(N_p2 / 16, 4);
|
||||
|
||||
// Given the above, how many CUs would we need?
|
||||
int CuNeeded = rndup_cus * GrpsShrB;
|
||||
|
||||
if (CuNeeded > CuCount) throw std::runtime_error("Invalid wvSplitKrc size");
|
||||
|
||||
// Can we increase SplitK by shrinking the K-shared to 256?
|
||||
int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1;
|
||||
|
||||
static torch::Tensor axl_glbl =
|
||||
torch::zeros(
|
||||
128 * 1024 * (_DTRMNSTC ? 12 : 1),
|
||||
torch::TensorOptions().dtype(torch::kFloat32).device(in_a.device()))
|
||||
.detach();
|
||||
static torch::Tensor axl_cntr =
|
||||
torch::zeros(
|
||||
128 * 1024 * (_DTRMNSTC ? 12 : 1) / 4,
|
||||
torch::TensorOptions().dtype(torch::kInt).device(in_a.device()))
|
||||
.detach();
|
||||
auto glbl = axl_glbl.data_ptr<float>();
|
||||
auto cntr = axl_cntr.data_ptr<int>();
|
||||
|
||||
#define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK) \
|
||||
{ \
|
||||
dim3 block(64, 4); \
|
||||
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK> \
|
||||
<<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, glbl, c, CuCount); \
|
||||
if (_DTRMNSTC) \
|
||||
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK, 1> \
|
||||
<<<grid, block, 0, stream>>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \
|
||||
af4, bf4, biasf4, glbl, cntr, c, \
|
||||
CuCount); \
|
||||
else \
|
||||
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK, 0> \
|
||||
<<<grid, block, 0, stream>>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \
|
||||
af4, bf4, biasf4, glbl, cntr, c, \
|
||||
CuCount); \
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitKrc", [&] {
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_a.scalar_type(), "wvSplitKrc", [&] {
|
||||
using fptype = typename scalar<scalar_t>::type;
|
||||
fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
|
||||
const fptype* af4 = reinterpret_cast<const fptype*>(in_a.data_ptr());
|
||||
const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
|
||||
const fptype* biasf4 =
|
||||
(in_bias.has_value() && in_bias->numel() > 0)
|
||||
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
|
||||
: nullptr;
|
||||
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
|
||||
auto glbl = axl_glbl.data_ptr<float>();
|
||||
|
||||
// With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
|
||||
// and each working on a 512-shard of K, how many CUs would we need?
|
||||
int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512);
|
||||
|
||||
// How many of 4 waves in a group can work on same 16 Ms at same time? First
|
||||
// try to maximize this. This reduces the Ms each group works on, i.e.
|
||||
// increasing the number of CUs needed.
|
||||
int GrpsShrB = min(N_p2 / 16, 4);
|
||||
|
||||
// Given the above, how many CUs would we need?
|
||||
int CuNeeded = rndup_cus * GrpsShrB;
|
||||
|
||||
if (CuNeeded > CuCount) std::runtime_error("Invalid wvSplitKrc size");
|
||||
|
||||
// Can we increase SplitK by shrinking the K-shared to 256?
|
||||
int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1;
|
||||
|
||||
switch (N_p2) {
|
||||
case 16:
|
||||
WVSPLITKrc(16, 1, 1) break;
|
||||
case 32:
|
||||
if (chunkk == 2)
|
||||
WVSPLITKrc(32, 2, 2) else if (chunkk == 1) WVSPLITKrc(32, 2, 1) break;
|
||||
if (chunkk == 2) WVSPLITKrc(32, 2, 2) else WVSPLITKrc(32, 2, 1) break;
|
||||
case 64:
|
||||
if (chunkk == 2)
|
||||
WVSPLITKrc(64, 4, 2) else if (chunkk == 1) WVSPLITKrc(64, 4, 1) break;
|
||||
if (chunkk == 2) WVSPLITKrc(64, 4, 2) else WVSPLITKrc(64, 4, 1) break;
|
||||
case 128:
|
||||
if (chunkk == 2)
|
||||
WVSPLITKrc(128, 4, 2) else if (chunkk == 1)
|
||||
WVSPLITKrc(128, 4, 1) break;
|
||||
if (chunkk == 2) WVSPLITKrc(128, 4, 2) else WVSPLITKrc(128, 4, 1) break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"Unsupported N value: " + std::to_string(M_in) + "," +
|
||||
|
||||
Reference in New Issue
Block a user