Perf tuning and expansion of cases covered for wvSplitKrc (#33493)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
@@ -1365,13 +1365,12 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
return out_c;
|
||||
}
|
||||
|
||||
#if defined(__gfx950__) // TODO: Add NAVI support
|
||||
// This version targets big A[] cases, where it is much larger than LDS
|
||||
// capacity
|
||||
// This version targets skinny cases where CUs are not filled
|
||||
// Wave-SplitK is used with reduction done via atomics.
|
||||
#if defined(__gfx950__)
|
||||
#define WVSPLITKRC_1KPASS
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N, int GrpsShrB>
|
||||
|
||||
int UNRL, int N, int GrpsShrB, int CHUNKK>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
__attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
wvSplitKrc_(const int actlN, const int K, const int M, const int Bx,
|
||||
@@ -1383,12 +1382,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
int* cntr = (int*)(&glbl[M * N]);
|
||||
|
||||
constexpr int NTILE = 16;
|
||||
constexpr int WVLDS_ = (NTILE * THRDS * A_CHUNK);
|
||||
constexpr int APAD = 1;
|
||||
constexpr int ASTRD = 64;
|
||||
constexpr int BPAD = 1;
|
||||
constexpr int BSTRD = 64;
|
||||
constexpr int WVLDS = ((WVLDS_ + (WVLDS_ / BSTRD) * 4 * BPAD));
|
||||
constexpr int WVLDS_ = THRDS * A_CHUNK / CHUNKK;
|
||||
constexpr int WVLDS = ((WVLDS_ + A_CHUNK * BPAD)) * YTILE;
|
||||
|
||||
constexpr int max_lds_len = LDS_SIZE / 2;
|
||||
|
||||
@@ -1442,17 +1440,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
break;
|
||||
}
|
||||
#else
|
||||
int constexpr kFit = 512;
|
||||
int constexpr kFit = 512 / CHUNKK;
|
||||
int constexpr kfitsPerRdc = 1;
|
||||
#endif
|
||||
|
||||
bool doRdc = (kfitsPerRdc * kFit < K);
|
||||
bool doRdc = true; // Assuming (kfitsPerRdc * kFit < K) is always true
|
||||
uint32_t numCuWithFullK =
|
||||
((M + (WvPrGrp * YTILE / GrpsShrB) - 1) / (WvPrGrp * YTILE / GrpsShrB));
|
||||
uint32_t Mmod = numCuWithFullK * (WvPrGrp * YTILE / GrpsShrB);
|
||||
|
||||
// given above k-split, find this wave's position
|
||||
uint32_t kFitPdd = kFit + (kFit / ASTRD) * APAD;
|
||||
uint32_t kFitPdd = kFit * CHUNKK + ((kFit * CHUNKK) / ASTRD) * APAD;
|
||||
uint32_t m0 = (blockIdx.x * WvPrGrp / GrpsShrB) * YTILE;
|
||||
uint32_t m1 = ((threadIdx.y % WvPrGrp) / GrpsShrB) * YTILE;
|
||||
uint32_t m = (m0 + m1) % Mmod;
|
||||
@@ -1460,8 +1458,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc;
|
||||
const uint32_t k_rnd = (K + kFit * kfitsPerRdc - 1) / (kFit * kfitsPerRdc);
|
||||
|
||||
scalar8 sum4[N / NTILE / GrpsShrB][1];
|
||||
bigType bigB_[YTILE / GrpsShrB][UNRL];
|
||||
scalar8 sum4[N / NTILE / GrpsShrB][1] = {0};
|
||||
bigType bigB_[YTILE / GrpsShrB / CHUNKK][UNRL];
|
||||
const uint32_t bLoader = (threadIdx.y % GrpsShrB);
|
||||
uint32_t kBase = 0;
|
||||
if (k_str >= K) return;
|
||||
@@ -1498,12 +1496,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k_str + k2 * THRDS * A_CHUNK;
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
uint32_t k_ = k + (threadIdx.x % (THRDS / CHUNKK)) * A_CHUNK;
|
||||
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
||||
#pragma unroll
|
||||
for (uint32_t y = 0; y < YTILE / GrpsShrB; y++)
|
||||
bigB_[y][k2].h8 = (loadnt(
|
||||
(scalar8*)(&B_[min__(y * GrpsShrB + bLoader + m, M - 1) * K])));
|
||||
for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK)
|
||||
bigB_[y / CHUNKK][k2].h8 = (loadnt(
|
||||
(scalar8*)(&B_[min__((y + threadIdx.x / (THRDS / CHUNKK)) * GrpsShrB +
|
||||
bLoader + m,
|
||||
M - 1) *
|
||||
K])));
|
||||
}
|
||||
{
|
||||
#else
|
||||
@@ -1556,48 +1557,51 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
if (reloada) {
|
||||
#endif
|
||||
constexpr int sprdN = 4;
|
||||
const uint32_t thrd = ((threadIdx.y / sprdN) * THRDS + threadIdx.x);
|
||||
const uint32_t thrd = threadIdx.x % (THRDS / CHUNKK);
|
||||
|
||||
#ifndef WVSPLITKRC_1KPASS
|
||||
#pragma unroll
|
||||
for (int k = 0; k < kFit; k += THRDS * (WvPrGrp / sprdN) * A_CHUNK) {
|
||||
for (int k = 0; k < kFit;
|
||||
k += (THRDS * (WvPrGrp / sprdN) * A_CHUNK) / CHUNKK) {
|
||||
#else
|
||||
const unsigned int k = 0;
|
||||
{
|
||||
#endif
|
||||
unsigned int kOff = k + (thrd * A_CHUNK);
|
||||
unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff);
|
||||
const unsigned int k_in = kOffcp + ((threadIdx.y % sprdN)) * K;
|
||||
const unsigned int k_ot = kOff + ((threadIdx.y % sprdN)) * kFitPdd;
|
||||
for (unsigned int n = 0; n < N / 2; n += sprdN) {
|
||||
__builtin_amdgcn_global_load_lds((int*)(&A[k_in + n * K]),
|
||||
(int*)(&s[(k_ot + n * kFitPdd)]),
|
||||
16, 0, 0);
|
||||
if (((threadIdx.y % sprdN)) + n + N / 2 >= actlN) continue;
|
||||
unsigned int kOffcp =
|
||||
k_str + kOff; // min__(K - A_CHUNK, k_str + kOff);
|
||||
for (unsigned int n = 0; n < N; n += CHUNKK * sprdN) {
|
||||
__builtin_amdgcn_global_load_lds(
|
||||
(int*)(&A[k_in + (n + N / 2) * K]),
|
||||
(int*)(&s[(k_ot + (n + N / 2) * kFitPdd)]), 16, 0, 0);
|
||||
(int*)(&A[min__(
|
||||
K * actlN - A_CHUNK,
|
||||
kOffcp + K * (n / CHUNKK +
|
||||
(N / CHUNKK) * (threadIdx.x / (64 / CHUNKK)) +
|
||||
(threadIdx.y % sprdN)))]),
|
||||
(int*)(&s[(k +
|
||||
kFitPdd * ((n / CHUNKK) + (threadIdx.y % sprdN)))]),
|
||||
16, 0, 0);
|
||||
}
|
||||
|
||||
// Stage loaded B[] to LDS for MFMA swizzling...
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
uint32_t k_ = k + (threadIdx.x % (THRDS / CHUNKK)) * A_CHUNK;
|
||||
const bool oob_k = (k_ >= K);
|
||||
for (uint32_t y = 0; y < YTILE / GrpsShrB; y++) {
|
||||
uint32_t idx = threadIdx.x * 4 +
|
||||
(y * GrpsShrB + bLoader) * ((THRDS + BPAD) * 4);
|
||||
for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK) {
|
||||
uint32_t idx =
|
||||
(threadIdx.x % (THRDS / CHUNKK)) * 4 +
|
||||
((y + threadIdx.x / (THRDS / CHUNKK)) * GrpsShrB + bLoader) *
|
||||
((THRDS / CHUNKK + BPAD) * 4);
|
||||
// zero out if oob
|
||||
*((scalar8*)&myStg[idx]) =
|
||||
(oob_k || (y * GrpsShrB + bLoader + m >= M))
|
||||
(oob_k) // TODO: ever necessary (y*GrpsShrB+bLoader+m>=M) ?
|
||||
? 0
|
||||
: bigB_[y][k2].h8;
|
||||
: bigB_[y / CHUNKK][k2].h8;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef WVSPLITKRC_1KPASS
|
||||
// Fire load of next B[] chunk...
|
||||
if ((k1 + THRDS * A_CHUNK * UNRL < k_end) &&
|
||||
@@ -1608,40 +1612,50 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
||||
#pragma unroll
|
||||
for (uint32_t y = 0; y < YTILE / GrpsShrB; y++)
|
||||
bigB_[y][k2].h8 = (loadnt(
|
||||
(scalar8*)(&B_[min__(y * GrpsShrB + bLoader + m, M - 1) * K])));
|
||||
for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK)
|
||||
bigB_[y / CHUNKK][k2].h8 = (loadnt(
|
||||
(scalar8*)(&B_[min__((y + threadIdx.x / (THRDS / CHUNKK)) *
|
||||
GrpsShrB +
|
||||
bLoader + m,
|
||||
M - 1) *
|
||||
K])));
|
||||
}
|
||||
#endif
|
||||
|
||||
// B[] staging is cooperative across GrpsShrB, so sync here before reading
|
||||
// back
|
||||
// back. This wait is currently inserted by the compiler, but not guaranteed.
|
||||
asm volatile("s_waitcnt 0");
|
||||
__syncthreads();
|
||||
|
||||
// read back B[] swizzled for MFMA...
|
||||
bigType bigB[YTILE][UNRL];
|
||||
bigType bigB[YTILE / CHUNKK][UNRL];
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
for (uint32_t y = 0; y < YTILE; y++) {
|
||||
unsigned int idx = (threadIdx.x % YTILE) * ((THRDS + BPAD) * 4) +
|
||||
(threadIdx.x / YTILE) * 4 + y * 16;
|
||||
for (uint32_t y = 0; y < YTILE / CHUNKK; y++) {
|
||||
unsigned int idx =
|
||||
(threadIdx.x % YTILE) * ((THRDS / CHUNKK + BPAD) * 4) +
|
||||
(threadIdx.x / YTILE) * 4 + y * 16;
|
||||
bigB[y][k2].h8 = *((scalar8*)&myStg[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
// Read back A[] swizzled for MFMA...
|
||||
bigType bigA[N / GrpsShrB][UNRL];
|
||||
bigType bigA[N / GrpsShrB / CHUNKK][UNRL];
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK - kBase - k_str;
|
||||
#pragma unroll
|
||||
for (uint32_t nt = 0; nt < N / GrpsShrB; nt += NTILE)
|
||||
#pragma unroll
|
||||
for (uint32_t n = 0; n < NTILE; n++) {
|
||||
uint32_t idxa = (nt + (threadIdx.x % NTILE) +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB)) *
|
||||
kFitPdd +
|
||||
A_CHUNK * ((threadIdx.x / NTILE) + n * 4) + k;
|
||||
bigA[nt + n][k2] = *((const bigType*)(&(s[idxa])));
|
||||
for (uint32_t n = 0; n < NTILE / CHUNKK; n++) {
|
||||
uint32_t idxa =
|
||||
((nt + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) % (N / CHUNKK) +
|
||||
(threadIdx.x % NTILE)) *
|
||||
kFitPdd +
|
||||
((nt + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) /
|
||||
(N / CHUNKK)) *
|
||||
A_CHUNK * (64 / CHUNKK) +
|
||||
A_CHUNK * ((threadIdx.x / NTILE) + n * 4) + k;
|
||||
bigA[nt / CHUNKK + n][k2] = *((const bigType*)(&(s[idxa])));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1650,152 +1664,75 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
#pragma unroll
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
||||
bigA[nt * NTILE + 0][k2].h4[0], bigB[0][k2].h4[0],
|
||||
(k1 == k_str) ? ((scalar8){0}) : sum4[nt][0], 0, 0, 0);
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
||||
bigA[nt * NTILE + 0][k2].h4[1], bigB[0][k2].h4[1], sum4[nt][0], 0,
|
||||
0, 0);
|
||||
} else { // bf16
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
bigA[nt * NTILE + 0][k2].h4[0], bigB[0][k2].h4[0],
|
||||
(k1 == k_str) ? ((scalar8){0}) : sum4[nt][0], 0, 0, 0);
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
bigA[nt * NTILE + 0][k2].h4[1], bigB[0][k2].h4[1], sum4[nt][0], 0,
|
||||
0, 0);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t j = 1; j < YTILE; j++) {
|
||||
for (uint32_t j = 0; j < YTILE / CHUNKK; j++) {
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
||||
bigA[nt * NTILE + j][k2].h4[0], bigB[j][k2].h4[0], sum4[nt][0],
|
||||
0, 0, 0);
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
||||
bigA[nt * NTILE + j][k2].h4[1], bigB[j][k2].h4[1], sum4[nt][0],
|
||||
0, 0, 0);
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x32_f16(
|
||||
bigA[nt * (YTILE / CHUNKK) + j][k2].h8, bigB[j][k2].h8,
|
||||
sum4[nt][0], 0, 0, 0);
|
||||
} else { // bf16
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
bigA[nt * NTILE + j][k2].h4[0], bigB[j][k2].h4[0], sum4[nt][0],
|
||||
0, 0, 0);
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
bigA[nt * NTILE + j][k2].h4[1], bigB[j][k2].h4[1], sum4[nt][0],
|
||||
0, 0, 0);
|
||||
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
|
||||
bigA[nt * (YTILE / CHUNKK) + j][k2].h8, bigB[j][k2].h8,
|
||||
sum4[nt][0], 0, 0, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!doRdc) {
|
||||
if (m + (threadIdx.x % 16) < M) {
|
||||
scalar_t biases[N / NTILE / GrpsShrB][4] = {0};
|
||||
if (m + (threadIdx.x % 16) < M) {
|
||||
int my_cntr;
|
||||
int mindx = m + (threadIdx.x % 16);
|
||||
int g_mindx = m * 4 + (threadIdx.x % 64); // coalesced atomic reduction
|
||||
scalar_t biases[N / NTILE / GrpsShrB][4] = {};
|
||||
// Atomic add the output, read biases
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
// int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
// (N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
// int adr = mindx + M * nindx;
|
||||
int g_nindx =
|
||||
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
|
||||
int g_adr = g_mindx + M * g_nindx * 4;
|
||||
atomicAdd(&glbl[g_adr], sum4[nt][0][j]);
|
||||
}
|
||||
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr_ = mindx + M * nindx_ / 4;
|
||||
// Update the complete counter
|
||||
my_cntr = atomicAdd(&cntr[adr_], 1);
|
||||
float vals[N / NTILE / GrpsShrB][4] = {};
|
||||
// If we're the last k-shard, read back the value and convert...
|
||||
if (my_cntr + 1 == k_rnd) {
|
||||
if (BIAS)
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int mindx = m + (threadIdx.x % 16);
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * M];
|
||||
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx];
|
||||
}
|
||||
}
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int mindx = m + (threadIdx.x % 16);
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr = mindx + M * nindx;
|
||||
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS) sum4[nt][0][j] += __bfloat162float(biases[nt][j]);
|
||||
C[adr] = __float2bfloat16(sum4[nt][0][j]);
|
||||
} else {
|
||||
if (BIAS) sum4[nt][0][j] += __half2float(biases[nt][j]);
|
||||
C[adr] = __float2half(sum4[nt][0][j]);
|
||||
}
|
||||
int g_nindx =
|
||||
j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4;
|
||||
int g_adr = g_mindx + M * g_nindx * 4;
|
||||
vals[nt][j] = glbl[g_adr];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (m + (threadIdx.x % 16) < M) {
|
||||
int my_cntr;
|
||||
if (!BIAS) {
|
||||
int mindx = m + (threadIdx.x % 16);
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
if (nindx < actlN) {
|
||||
int adr = mindx + M * nindx;
|
||||
atomicAdd(&glbl[adr], sum4[nt][0][j]);
|
||||
}
|
||||
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr_ = mindx + M * nindx_ / 4;
|
||||
my_cntr = atomicAdd(&cntr[adr_], 1);
|
||||
float vals[N / NTILE / GrpsShrB][4] = {};
|
||||
if (my_cntr + 1 == k_rnd) {
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr = mindx + M * nindx;
|
||||
vals[nt][j] = glbl[adr];
|
||||
}
|
||||
}
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
if (nindx >= actlN) break;
|
||||
int adr = mindx + M * nindx;
|
||||
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
C[adr] = __float2bfloat16(vals[nt][j]);
|
||||
} else {
|
||||
C[adr] = __float2half(vals[nt][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
int mindx = m + (threadIdx.x % 16);
|
||||
scalar_t biases[N / NTILE / GrpsShrB][4] = {};
|
||||
// Atomic add the output, read biases
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr = mindx + M * nindx;
|
||||
atomicAdd(&glbl[adr], sum4[nt][0][j]);
|
||||
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * M];
|
||||
}
|
||||
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr_ = mindx + M * nindx_ / 4;
|
||||
// Update the complete counter
|
||||
my_cntr = atomicAdd(&cntr[adr_], 1);
|
||||
float vals[N / NTILE / GrpsShrB][4] = {};
|
||||
// If we're the last k-shard, read back the value and convert...
|
||||
if (my_cntr + 1 == k_rnd) {
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
int adr = mindx + M * nindx;
|
||||
vals[nt][j] = glbl[adr];
|
||||
}
|
||||
}
|
||||
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
||||
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
||||
if (nindx >= actlN) break;
|
||||
int adr = mindx + M * nindx;
|
||||
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
vals[nt][j] += __bfloat162float(biases[nt][j]);
|
||||
C[adr] = __float2bfloat16(vals[nt][j]);
|
||||
} else {
|
||||
vals[nt][j] += __half2float(biases[nt][j]);
|
||||
C[adr] = __float2half(vals[nt][j]);
|
||||
}
|
||||
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
vals[nt][j] += __bfloat162float(biases[nt][j]);
|
||||
C[adr] = __float2bfloat16(vals[nt][j]);
|
||||
} else {
|
||||
vals[nt][j] += __half2float(biases[nt][j]);
|
||||
C[adr] = __float2half(vals[nt][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1814,7 +1751,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
}
|
||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||
int UNRL, int N, int GrpsShrB>
|
||||
int UNRL, int N, int GrpsShrB, int CHUNKK>
|
||||
__global__ void wvSplitKrc_(const int actlN, const int K, const int M,
|
||||
const int Bx, const int By, const scalar_t* B,
|
||||
const scalar_t* __restrict__ A,
|
||||
@@ -1859,10 +1796,10 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
// const int max_lds_len = get_lds_size() / 2;
|
||||
|
||||
#define WVSPLITKrc(_WvPrGrp, _YTILE, _UNRL, _N, _GrpsShrB) \
|
||||
#define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK) \
|
||||
{ \
|
||||
dim3 block(64, _WvPrGrp); \
|
||||
wvSplitKrc_<fptype, 64, _YTILE, _WvPrGrp, 8, _UNRL, _N, _GrpsShrB> \
|
||||
dim3 block(64, 4); \
|
||||
wvSplitKrc_<fptype, 64, 16, 4, 8, 1, _N, _GrpsShrB, _CHUNKK> \
|
||||
<<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||
biasf4, glbl, c, CuCount); \
|
||||
}
|
||||
@@ -1877,15 +1814,37 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
: nullptr;
|
||||
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
|
||||
auto glbl = axl_glbl.data_ptr<float>();
|
||||
|
||||
// With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
|
||||
// and each working on a 512-shard of K, how many CUs would we need?
|
||||
int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512);
|
||||
|
||||
// How many of 4 waves in a group can work on same 16 Ms at same time? First
|
||||
// try to maximize this. This reduces the Ms each group works on, i.e.
|
||||
// increasing the number of CUs needed.
|
||||
int GrpsShrB = min(N_p2 / 16, 4);
|
||||
|
||||
// Given the above, how many CUs would we need?
|
||||
int CuNeeded = rndup_cus * GrpsShrB;
|
||||
|
||||
if (CuNeeded > CuCount) std::runtime_error("Invalid wvSplitKrc size");
|
||||
|
||||
// Can we increase SplitK by shrinking the K-shard to 256?
|
||||
int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1;
|
||||
|
||||
switch (N_p2) {
|
||||
case 16:
|
||||
WVSPLITKrc(4, 16, 1, 16, 1) break;
|
||||
WVSPLITKrc(16, 1, 1) break;
|
||||
case 32:
|
||||
WVSPLITKrc(4, 16, 1, 32, 2) break;
|
||||
if (chunkk == 2)
|
||||
WVSPLITKrc(32, 2, 2) else if (chunkk == 1) WVSPLITKrc(32, 2, 1) break;
|
||||
case 64:
|
||||
WVSPLITKrc(4, 16, 1, 64, 2) break;
|
||||
if (chunkk == 2)
|
||||
WVSPLITKrc(64, 4, 2) else if (chunkk == 1) WVSPLITKrc(64, 4, 1) break;
|
||||
case 128:
|
||||
WVSPLITKrc(4, 16, 1, 128, 4) break;
|
||||
if (chunkk == 2)
|
||||
WVSPLITKrc(128, 4, 2) else if (chunkk == 1)
|
||||
WVSPLITKrc(128, 4, 1) break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"Unsupported N value: " + std::to_string(M_in) + "," +
|
||||
|
||||
@@ -45,31 +45,28 @@ NKM_FACTORS_WVSPLITK = [
|
||||
(4, 256, 8),
|
||||
]
|
||||
|
||||
NKM_FACTORS_WVSPLITKRC = [
|
||||
(16, 2880, 128),
|
||||
(16, 2880, 640),
|
||||
(17, 2880, 128),
|
||||
(17, 2880, 640),
|
||||
(25, 2880, 128),
|
||||
(25, 2880, 640),
|
||||
(31, 2880, 128),
|
||||
(31, 2880, 640),
|
||||
(32, 2880, 128),
|
||||
(32, 2880, 640),
|
||||
(40, 2880, 128),
|
||||
(40, 2880, 640),
|
||||
(60, 2880, 128),
|
||||
(60, 2880, 640),
|
||||
(64, 2880, 128),
|
||||
(64, 2880, 640),
|
||||
(81, 2880, 128),
|
||||
(81, 2880, 640),
|
||||
(98, 2880, 128),
|
||||
(98, 2880, 640),
|
||||
(128, 2880, 128),
|
||||
(128, 2880, 640),
|
||||
N_FACTORS_WVSPLITKRC = [
|
||||
13,
|
||||
16,
|
||||
17,
|
||||
25,
|
||||
29,
|
||||
31,
|
||||
32,
|
||||
41,
|
||||
51,
|
||||
64,
|
||||
71,
|
||||
81,
|
||||
91,
|
||||
103,
|
||||
117,
|
||||
128,
|
||||
]
|
||||
|
||||
K_FACTORS_WVSPLITKRC = [2880, 2880 + 8, 3072, 3072 + 8]
|
||||
M_FACTORS_WVSPLITKRC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16]
|
||||
|
||||
NKM_FACTORS_WVSPLITK_FP8 = [
|
||||
# FP8-specific cases with K % 16 == 0
|
||||
(1, 16, 16),
|
||||
@@ -113,30 +110,54 @@ def pad_fp8(weight):
|
||||
return F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC)
|
||||
@pytest.mark.parametrize("xnorm", [False, True])
|
||||
@pytest.mark.parametrize("n", N_FACTORS_WVSPLITKRC)
|
||||
@pytest.mark.parametrize("k", K_FACTORS_WVSPLITKRC)
|
||||
@pytest.mark.parametrize("m", M_FACTORS_WVSPLITKRC)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
|
||||
@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950")
|
||||
def test_rocm_wvsplitkrc_kernel(n, k, m, dtype, seed, bias_mode):
|
||||
def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = get_cu_count()
|
||||
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
|
||||
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
|
||||
# Smallest power of 2 that is >= n
|
||||
N_p2 = 1 << (n - 1).bit_length()
|
||||
# With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
|
||||
# and each working on a 512-shard of K, how many CUs would we need?
|
||||
rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
|
||||
# How many of 4 waves in a group can work on same 16 Ms at same time?
|
||||
# This reduces the Ms each group works on, i.e. increasing the number of CUs needed.
|
||||
GrpsShrB = min(N_p2 // 16, 4)
|
||||
# Given the above, how many CUs would we need?
|
||||
CuNeeded = rndup_cus * GrpsShrB
|
||||
# candidate for atomic reduce count splitk?
|
||||
fits_wvsplitkrc = CuNeeded <= cu_count
|
||||
|
||||
if not fits_wvsplitkrc:
|
||||
pytest.skip("Too large for wvSplitKrc")
|
||||
|
||||
xavier = (
|
||||
math.sqrt(2 / k) if xnorm else 1
|
||||
) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
|
||||
B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
|
||||
|
||||
BIAS = None
|
||||
if bias_mode == 1:
|
||||
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
|
||||
BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
|
||||
elif bias_mode == 2:
|
||||
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5
|
||||
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
|
||||
|
||||
ref_out = torch.nn.functional.linear(A, B, BIAS)
|
||||
out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
if xnorm:
|
||||
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
|
||||
else:
|
||||
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
|
||||
|
||||
@@ -145,32 +145,43 @@ def rocm_unquantized_gemm_impl(
|
||||
) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_gfx9, on_gfx950
|
||||
|
||||
n = x.numel() / x.size(-1)
|
||||
n = x.numel() // x.size(-1)
|
||||
m = weight.shape[0]
|
||||
k = weight.shape[1]
|
||||
|
||||
import math
|
||||
|
||||
cu_count = get_cu_count()
|
||||
if use_aiter_triton_gemm(n, m, k, x.dtype):
|
||||
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
||||
|
||||
return gemm_a16w16(x, weight, bias)
|
||||
|
||||
# Smallest power of 2 that is >= n
|
||||
N_p2 = 1 << (n - 1).bit_length()
|
||||
# With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
|
||||
# and each working on a 512-shard of K, how many CUs would we need?
|
||||
rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
|
||||
# How many of 4 waves in a group can work on same 16 Ms at same time?
|
||||
# This reduces the Ms each group works on, i.e. increasing the number of CUs needed.
|
||||
GrpsShrB = min(N_p2 // 16, 4)
|
||||
# Given the above, how many CUs would we need?
|
||||
CuNeeded = rndup_cus * GrpsShrB
|
||||
# candidate for atomic reduce count splitk?
|
||||
fits_wvsplitkrc = CuNeeded <= cu_count
|
||||
|
||||
use_skinny_reduce_counting = (
|
||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||
and on_gfx950()
|
||||
and x.dtype in [torch.float16, torch.bfloat16]
|
||||
and (
|
||||
n >= 16
|
||||
and n <= 128
|
||||
10 <= n <= 128
|
||||
and k % 8 == 0
|
||||
and k > 512
|
||||
and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count()
|
||||
and m % 16 == 0
|
||||
and fits_wvsplitkrc
|
||||
and x.is_contiguous()
|
||||
)
|
||||
# k == 2880 and (m == 640 or m == 128))
|
||||
)
|
||||
if use_skinny_reduce_counting:
|
||||
cu_count = get_cu_count()
|
||||
x_view = x.reshape(-1, x.size(-1))
|
||||
out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
|
||||
return out.reshape(*x.shape[:-1], weight.shape[0])
|
||||
|
||||
Reference in New Issue
Block a user