Perf tuning and expansion of cases covered for wvSplitKrc (#33493)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
@@ -1365,13 +1365,12 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
return out_c;
|
||||
}
|
||||
|
||||
#if defined(__gfx950__) // TODO: Add NAVI support
|
||||
// This version targets big A[] cases, where it is much larger than LDS
|
||||
// capacity
|
||||
// This version targets cases skinny where CUs are not filled
|
||||
// Wave-SplitK is used with reduction done via atomics.
|
||||
#if defined(__gfx950__)
|
||||
#define WVSPLITKRC_1KPASS
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N, int GrpsShrB>
|
||||
|
||||
int UNRL, int N, int GrpsShrB, int CHUNKK>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
__attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
wvSplitKrc_(const int actlN, const int K, const int M, const int Bx,
|
||||
@@ -1383,12 +1382,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
int* cntr = (int*)(&glbl[M * N]);
|
||||
|
||||
constexpr int NTILE = 16;
|
||||
constexpr int WVLDS_ = (NTILE * THRDS * A_CHUNK);
|
||||
constexpr int APAD = 1;
|
||||
constexpr int ASTRD = 64;
|
||||
constexpr int BPAD = 1;
|
||||
constexpr int BSTRD = 64;
|
||||
constexpr int WVLDS = ((WVLDS_ + (WVLDS_ / BSTRD) * 4 * BPAD));
|
||||
constexpr int WVLDS_ = THRDS * A_CHUNK / CHUNKK;
|
||||
constexpr int WVLDS = ((WVLDS_ + A_CHUNK * BPAD)) * YTILE;
|
||||
|
||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||
|
||||
@@ -1442,17 +1440,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
break;
|
||||
}
|
||||
#else
|
||||
int constexpr kFit = 512;
|
||||
int constexpr kFit = 512 / CHUNKK;
|
||||
int constexpr kfitsPerRdc = 1;
|
||||
#endif
|
||||
|
||||
bool doRdc = (kfitsPerRdc * kFit < K);
|
||||
bool doRdc = true; // Assuming (kfitsPerRdc * kFit < K) is always true
|
||||
uint32_t numCuWithFullK =
|
||||
((M + (WvPrGrp * YTILE / GrpsShrB) - 1) / (WvPrGrp * YTILE / GrpsShrB));
|
||||
uint32_t Mmod = numCuWithFullK * (WvPrGrp * YTILE / GrpsShrB);
|
||||
|
||||
// given above k-split, find this wave's position
|
||||
uint32_t kFitPdd = kFit + (kFit / ASTRD) * APAD;
|
||||
uint32_t kFitPdd = kFit * CHUNKK + ((kFit * CHUNKK) / ASTRD) * APAD;
|
||||
uint32_t m0 = (blockIdx.x * WvPrGrp / GrpsShrB) * YTILE;
|
||||
uint32_t m1 = ((threadIdx.y % WvPrGrp) / GrpsShrB) * YTILE;
|
||||
uint32_t m = (m0 + m1) % Mmod;
|
||||
@@ -1460,8 +1458,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc;
|
||||
const uint32_t k_rnd = (K + kFit * kfitsPerRdc - 1) / (kFit * kfitsPerRdc);
|
||||
|
||||
scalar8 sum4[N / NTILE / GrpsShrB][1];
|
||||
bigType bigB_[YTILE / GrpsShrB][UNRL];
|
||||
scalar8 sum4[N / NTILE / GrpsShrB][1] = {0};
|
||||
bigType bigB_[YTILE / GrpsShrB / CHUNKK][UNRL];
|
||||
const uint32_t bLoader = (threadIdx.y % GrpsShrB);
|
||||
uint32_t kBase = 0;
|
||||
if (k_str >= K) return;
|
||||
@@ -1498,12 +1496,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k_str + k2 * THRDS * A_CHUNK;
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
uint32_t k_ = k + (threadIdx.x % (THRDS / CHUNKK)) * A_CHUNK;
|
||||
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
||||
#pragma unroll
|
||||
for (uint32_t y = 0; y < YTILE / GrpsShrB; y++)
|
||||
bigB_[y][k2].h8 = (loadnt(
|
||||
(scalar8*)(&B_[min__(y * GrpsShrB + bLoader + m, M - 1) * K])));
|
||||
for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK)
|
||||
bigB_[y / CHUNKK][k2].h8 = (loadnt(
|
||||
(scalar8*)(&B_[min__((y + threadIdx.x / (THRDS / CHUNKK)) * GrpsShrB +
|
||||
bLoader + m,
|
||||
M - 1) *
|
||||
K])));
|
||||
}
|
||||
{
|
||||
#else
|
||||
@@ -1556,48 +1557,51 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (reloada) {
|
||||
#endif
|
||||
constexpr int sprdN = 4;
|
||||
const uint32_t thrd = ((threadIdx.y / sprdN) * THRDS + threadIdx.x);
|
||||
const uint32_t thrd = threadIdx.x % (THRDS / CHUNKK);
|
||||
|
||||
#ifndef WVSPLITKRC_1KPASS
|
||||
#pragma unroll
|
||||
for (int k = 0; k < kFit; k += THRDS * (WvPrGrp / sprdN) * A_CHUNK) {
|
||||
for (int k = 0; k < kFit;
|
||||
k += (THRDS * (WvPrGrp / sprdN) * A_CHUNK) / CHUNKK) {
|
||||
#else
|
||||
const unsigned int k = 0;
|
||||
{
|
||||
#endif
|
||||
unsigned int kOff = k + (thrd * A_CHUNK);
|
||||
unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff);
|
||||
const unsigned int k_in = kOffcp + ((threadIdx.y % sprdN)) * K;
|
||||
const unsigned int k_ot = kOff + ((threadIdx.y % sprdN)) * kFitPdd;
|
||||
for (unsigned int n = 0; n < N / 2; n += sprdN) {
|
||||
__builtin_amdgcn_global_load_lds((int*)(&A[k_in + n * K]),
|
||||
(int*)(&s[(k_ot + n * kFitPdd)]),
|
||||
16, 0, 0);
|
||||
if (((threadIdx.y % sprdN)) + n + N / 2 >= actlN) continue;
|
||||
unsigned int kOffcp =
|
||||
k_str + kOff; // min__(K - A_CHUNK, k_str + kOff);
|
||||
for (unsigned int n = 0; n < N; n += CHUNKK * sprdN) {
|
||||
__builtin_amdgcn_global_load_lds(
|
||||
(int*)(&A[k_in + (n + N / 2) * K]),
|
||||
(int*)(&s[(k_ot + (n + N / 2) * kFitPdd)]), 16, 0, 0);
|
||||
(int*)(&A[min__(
|
||||
K * actlN - A_CHUNK,
|
||||
kOffcp + K * (n / CHUNKK +
|
||||
(N / CHUNKK) * (threadIdx.x / (64 / CHUNKK)) +
|
||||
(threadIdx.y % sprdN)))]),
|
||||
(int*)(&s[(k +
|
||||
kFitPdd * ((n / CHUNKK) + (threadIdx.y % sprdN)))]),
|
||||
16, 0, 0);
|
||||
}
|
||||
|
||||
// Stage loaded B[] to LDS for MFMA swizzling...
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
uint32_t k_ = k + (threadIdx.x % (THRDS / CHUNKK)) * A_CHUNK;
|
||||
const bool oob_k = (k_ >= K);
|
||||
for (uint32_t y = 0; y < YTILE / GrpsShrB; y++) {
|
||||
uint32_t idx = threadIdx.x * 4 +
|
||||
(y * GrpsShrB + bLoader) * ((THRDS + BPAD) * 4);
|
||||
for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK) {
|
||||
uint32_t idx =
|
||||
(threadIdx.x % (THRDS / CHUNKK)) * 4 +
|
||||
((y + threadIdx.x / (THRDS / CHUNKK)) * GrpsShrB + bLoader) *
|
||||
((THRDS / CHUNKK + BPAD) * 4);
|
||||
// zero out if oob
|
||||
*((scalar8*)&myStg[idx]) =
|
||||
(oob_k || (y * GrpsShrB + bLoader + m >= M))
|
||||
(oob_k) // TODO: ever necessary (y*GrpsShrB+bLoader+m>=M) ?
|
||||
? 0
|
||||
: bigB_[y][k2].h8;
|
||||
: bigB_[y / CHUNKK][k2].h8;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef WVSPLITKRC_1KPASS
|
||||
// Fire load of next B[] chunk...
|
||||
if ((k1 + THRDS * A_CHUNK * UNRL < k_end) &&
|
||||
@@ -1608,40 +1612,50 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
||||
#pragma unroll
|
||||
for (uint32_t y = 0; y < YTILE / GrpsShrB; y++)
|
||||
bigB_[y][k2].h8 = (loadnt(
|
||||
(scalar8*)(&B_[min__(y * GrpsShrB + bLoader + m, M - 1) * K])));
|
||||
for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK)
|
||||
bigB_[y / CHUNKK][k2].h8 = (loadnt(
|
||||
(scalar8*)(&B_[min__((y + threadIdx.x / (THRDS / CHUNKK)) *
|
||||
GrpsShrB +
|
||||
bLoader + m,
|
||||
M - 1) *
|
||||
K])));
|
||||
}
|
||||
#endif
|
||||
|
||||
// B[] staging is cooperative across GrpsShrB, so sync here before reading
|
||||
// back
|
||||
// back. This wait is currently inserted by compiler, but not gauranteed.
|
||||
asm volatile("s_waitcnt 0");
|
||||
__syncthreads();
|
||||
|
||||
// read back B[] swizzled for MFMA...
|
||||
bigType bigB[YTILE][UNRL];
|
||||
bigType bigB[YTILE / CHUNKK][UNRL];
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
for (uint32_t y = 0; y < YTILE; y++) {
|
||||
unsigned int idx = (threadIdx.x % YTILE) * ((THRDS + BPAD) * 4) +
|
||||
(threadIdx.x / YTILE) * 4 + y * 16;
|
||||
for (uint32_t y = 0; y < YTILE / CHUNKK; y++) {
|
||||
unsigned int idx =
|
||||
(threadIdx.x % YTILE) * ((THRDS / CHUNKK + BPAD) * 4) +
|
||||
(threadIdx.x / YTILE) * 4 + y * 16;
|
||||
bigB[y][k2].h8 = *((scalar8*)&myStg[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
// rReadback A[] swizzled for MFMA...
|
||||
bigType bigA[N / GrpsShrB][UNRL];
|
||||
bigType bigA[N / GrpsShrB / CHUNKK][UNRL];
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK - kBase - k_str;
|
||||
#pragma unroll
|
||||
for (uint32_t nt = 0; nt < N / GrpsShrB; nt += NTILE)
|
||||
#pragma unroll
|
||||
for (uint32_t n = 0; n < NTILE; n++) {
|
||||
uint32_t idxa = (nt + (threadIdx.x % NTILE) +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB)) *
|
||||
kFitPdd +
|
||||
A_CHUNK * ((threadIdx.x / NTILE) + n * 4) + k;
|
||||
bigA[nt + n][k2] = *((const bigType*)(&(s[idxa])));
|
||||
for (uint32_t n = 0; n < NTILE / CHUNKK; n++) {
|
||||
uint32_t idxa =
|
||||
((nt + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) % (N / CHUNKK) +
|
||||
(threadIdx.x % NTILE)) *
|
||||
kFitPdd +
|
||||
((nt + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) /
|
||||
(N / CHUNKK)) *
|
||||
A_CHUNK * (64 / CHUNKK) +
|
||||
A_CHUNK * ((threadIdx.x / NTILE) + n * 4) + k;
|
||||
bigA[nt / CHUNKK + n][k2] = *((const bigType*)(&(s[idxa])));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1650,152 +1664,75 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
#pragma unroll
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
||||
bigA[nt * NTILE + 0][k2].h4[0], bigB[0][k2].h4[0],
|
||||
(k1 == k_str) ? ((scalar8){0}) : sum4[nt][0], 0, 0, 0);
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
||||
bigA[nt * NTILE + 0][k2].h4[1], bigB[0][k2].h4[1], sum4[nt][0], 0,
|
||||
0, 0);
|
||||
} else { // bf16
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
bigA[nt * NTILE + 0][k2].h4[0], bigB[0][k2].h4[0],
|
||||
(k1 == k_str) ? ((scalar8){0}) : sum4[nt][0], 0, 0, 0);
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
bigA[nt * NTILE + 0][k2].h4[1], bigB[0][k2].h4[1], sum4[nt][0], 0,
|
||||
0, 0);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t j = 1; j < YTILE; j++) {
|
||||
for (uint32_t j = 0; j < YTILE / CHUNKK; j++) {
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
||||
bigA[nt * NTILE + j][k2].h4[0], bigB[j][k2].h4[0], sum4[nt][0],
|
||||
0, 0, 0);
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
||||
bigA[nt * NTILE + j][k2].h4[1], bigB[j][k2].h4[1], sum4[nt][0],
|
||||
0, 0, 0);
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x32_f16(
|
||||
bigA[nt * (YTILE / CHUNKK) + j][k2].h8, bigB[j][k2].h8,
|
||||
sum4[nt][0], 0, 0, 0);
|
||||
} else { // bf16
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
bigA[nt * NTILE + j][k2].h4[0], bigB[j][k2].h4[0], sum4[nt][0],
|
||||
0, 0, 0);
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
bigA[nt * NTILE + j][k2].h4[1], bigB[j][k2].h4[1], sum4[nt][0],
|
||||
0, 0, 0);
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
|
||||
bigA[nt * (YTILE / CHUNKK) + j][k2].h8, bigB[j][k2].h8,
|
||||
sum4[nt][0], 0, 0, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!doRdc) {
|
||||
if (m + (threadIdx.x % 16) < M) {
|
||||
scalar_t biases[N / NTILE / GrpsShrB][4] = {0};
|
||||
if (m + (threadIdx.x % 16) < M) {
|
||||
int my_cntr;
|
||||
int mindx = m + (threadIdx.x % 16);
|
||||
int g_mindx = m * 4 + (threadIdx.x % 64); // coalesced atomic reduction
|
||||
scalar_t biases[N / NTILE / GrpsShrB][4] = {};
|
||||
// Atomic add the output, read biases
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
// int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
// (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
// int adr = mindx + M * nindx;
|
||||
int g_nindx =
|
||||
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
|
||||
int g_adr = g_mindx + M * g_nindx * 4;
|
||||
atomicAdd(&glbl[g_adr], sum4[nt][0][j]);
|
||||
}
|
||||
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr_ = mindx + M * nindx_ / 4;
|
||||
// Update the complete counter
|
||||
my_cntr = atomicAdd(&cntr[adr_], 1);
|
||||
float vals[N / NTILE / GrpsShrB][4] = {};
|
||||
// If we're the last k-shard, read back the value and convert...
|
||||
if (my_cntr + 1 == k_rnd) {
|
||||
if (BIAS)
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int mindx = m + (threadIdx.x % 16);
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * M];
|
||||
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
|
||||
}
|
||||
}
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int mindx = m + (threadIdx.x % 16);
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr = mindx + M * nindx;
|
||||
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS) sum4[nt][0][j] += __bfloat162float(biases[nt][j]);
|
||||
C[adr] = __float2bfloat16(sum4[nt][0][j]);
|
||||
} else {
|
||||
if (BIAS) sum4[nt][0][j] += __half2float(biases[nt][j]);
|
||||
C[adr] = __float2half(sum4[nt][0][j]);
|
||||
}
|
||||
int g_nindx =
|
||||
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
|
||||
int g_adr = g_mindx + M * g_nindx * 4;
|
||||
vals[nt][j] = glbl[g_adr];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (m + (threadIdx.x % 16) < M) {
|
||||
int my_cntr;
|
||||
if (!BIAS) {
|
||||
int mindx = m + (threadIdx.x % 16);
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
if (nindx < actlN) {
|
||||
int adr = mindx + M * nindx;
|
||||
atomicAdd(&glbl[adr], sum4[nt][0][j]);
|
||||
}
|
||||
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr_ = mindx + M * nindx_ / 4;
|
||||
my_cntr = atomicAdd(&cntr[adr_], 1);
|
||||
float vals[N / NTILE / GrpsShrB][4] = {};
|
||||
if (my_cntr + 1 == k_rnd) {
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr = mindx + M * nindx;
|
||||
vals[nt][j] = glbl[adr];
|
||||
}
|
||||
}
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
if (nindx >= actlN) break;
|
||||
int adr = mindx + M * nindx;
|
||||
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
C[adr] = __float2bfloat16(vals[nt][j]);
|
||||
} else {
|
||||
C[adr] = __float2half(vals[nt][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int mindx = m + (threadIdx.x % 16);
|
||||
scalar_t biases[N / NTILE / GrpsShrB][4] = {};
|
||||
// Atomic add the output, read biases
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr = mindx + M * nindx;
|
||||
atomicAdd(&glbl[adr], sum4[nt][0][j]);
|
||||
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * M];
|
||||
}
|
||||
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr_ = mindx + M * nindx_ / 4;
|
||||
// Update the complete counter
|
||||
my_cntr = atomicAdd(&cntr[adr_], 1);
|
||||
float vals[N / NTILE / GrpsShrB][4] = {};
|
||||
// If we're the last k-shard, read back the value and convert...
|
||||
if (my_cntr + 1 == k_rnd) {
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr = mindx + M * nindx;
|
||||
vals[nt][j] = glbl[adr];
|
||||
}
|
||||
}
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
if (nindx >= actlN) break;
|
||||
int adr = mindx + M * nindx;
|
||||
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
vals[nt][j] += __bfloat162float(biases[nt][j]);
|
||||
C[adr] = __float2bfloat16(vals[nt][j]);
|
||||
} else {
|
||||
vals[nt][j] += __half2float(biases[nt][j]);
|
||||
C[adr] = __float2half(vals[nt][j]);
|
||||
}
|
||||
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
vals[nt][j] += __bfloat162float(biases[nt][j]);
|
||||
C[adr] = __float2bfloat16(vals[nt][j]);
|
||||
} else {
|
||||
vals[nt][j] += __half2float(biases[nt][j]);
|
||||
C[adr] = __float2half(vals[nt][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1814,7 +1751,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
}
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N, int GrpsShrB>
|
||||
int UNRL, int N, int GrpsShrB, int CHUNKK>
|
||||
__global__ void wvSplitKrc_(const int actlN, const int K, const int M,
|
||||
const int Bx, const int By, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A,
|
||||
@@ -1859,10 +1796,10 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
// const int max_lds_len = get_lds_size() / 2;
|
||||
|
||||
#define WVSPLITKrc(_WvPrGrp, _YTILE, _UNRL, _N, _GrpsShrB) \
|
||||
#define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK) \
|
||||
{ \
|
||||
dim3 block(64, _WvPrGrp); \
|
||||
wvSplitKrc_<fptype, 64, _YTILE, _WvPrGrp, 8, _UNRL, _N, _GrpsShrB> \
|
||||
dim3 block(64, 4); \
|
||||
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK> \
|
||||
<<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, glbl, c, CuCount); \
|
||||
}
|
||||
@@ -1877,15 +1814,37 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
: nullptr;
|
||||
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
|
||||
auto glbl = axl_glbl.data_ptr<float>();
|
||||
|
||||
// With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
|
||||
// and each working on a 512-shard of K, how many CUs would we need?
|
||||
int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512);
|
||||
|
||||
// How many of 4 waves in a group can work on same 16 Ms at same time? First
|
||||
// try to maximize this. This reduces the Ms each group works on, i.e.
|
||||
// increasing the number of CUs needed.
|
||||
int GrpsShrB = min(N_p2 / 16, 4);
|
||||
|
||||
// Given the above, how many CUs would we need?
|
||||
int CuNeeded = rndup_cus * GrpsShrB;
|
||||
|
||||
if (CuNeeded > CuCount) std::runtime_error("Invalid wvSplitKrc size");
|
||||
|
||||
// Can we increase SplitK by shrinking the K-shared to 256?
|
||||
int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1;
|
||||
|
||||
switch (N_p2) {
|
||||
case 16:
|
||||
WVSPLITKrc(4, 16, 1, 16, 1) break;
|
||||
WVSPLITKrc(16, 1, 1) break;
|
||||
case 32:
|
||||
WVSPLITKrc(4, 16, 1, 32, 2) break;
|
||||
if (chunkk == 2)
|
||||
WVSPLITKrc(32, 2, 2) else if (chunkk == 1) WVSPLITKrc(32, 2, 1) break;
|
||||
case 64:
|
||||
WVSPLITKrc(4, 16, 1, 64, 2) break;
|
||||
if (chunkk == 2)
|
||||
WVSPLITKrc(64, 4, 2) else if (chunkk == 1) WVSPLITKrc(64, 4, 1) break;
|
||||
case 128:
|
||||
WVSPLITKrc(4, 16, 1, 128, 4) break;
|
||||
if (chunkk == 2)
|
||||
WVSPLITKrc(128, 4, 2) else if (chunkk == 1)
|
||||
WVSPLITKrc(128, 4, 1) break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"Unsupported N value: " + std::to_string(M_in) + "," +
|
||||
|
||||
Reference in New Issue
Block a user