|
|
|
|
@@ -287,6 +287,11 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
|
|
|
|
|
V0 += (s.x + s.y); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// To avoid LLVM silently upcasting to double
|
|
|
|
|
__device__ inline unsigned int min__(uint32_t a, uint32_t b) {
|
|
|
|
|
return min(a, b);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
|
|
|
|
|
// This version targets cases where A[] fits LDS capacity
|
|
|
|
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
|
|
|
|
@@ -334,11 +339,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|
|
|
|
// - Then the WG will move to another 8 K elements
|
|
|
|
|
// TODO: Logic below will only work when K is multiple of 8
|
|
|
|
|
//----------------------------------------------------
|
|
|
|
|
for (uint32_t k = 0; k < min(K * N, max_lds_len);
|
|
|
|
|
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
|
|
|
|
|
k += THRDS * WvPrGrp * A_CHUNK) {
|
|
|
|
|
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
|
|
|
|
|
|
|
|
|
|
if (k_in >= min(K * N, max_lds_len)) break;
|
|
|
|
|
if (k_in >= min__(K * N, max_lds_len)) break;
|
|
|
|
|
|
|
|
|
|
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
|
|
|
|
|
}
|
|
|
|
|
@@ -633,11 +638,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|
|
|
|
// - Then the WG will move to another 8 K elements
|
|
|
|
|
// TODO: Logic below will only work when K is multiple of 8
|
|
|
|
|
//----------------------------------------------------
|
|
|
|
|
for (uint32_t k = 0; k < min(K * N, max_lds_len);
|
|
|
|
|
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
|
|
|
|
|
k += THRDS * WvPrGrp * A_CHUNK) {
|
|
|
|
|
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
|
|
|
|
|
|
|
|
|
|
if (k_in >= min(K * N, max_lds_len)) break;
|
|
|
|
|
if (k_in >= min__(K * N, max_lds_len)) break;
|
|
|
|
|
|
|
|
|
|
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
|
|
|
|
|
}
|
|
|
|
|
@@ -954,11 +959,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|
|
|
|
//----------------------------------------------------
|
|
|
|
|
#define PCML
|
|
|
|
|
#ifndef PCML
|
|
|
|
|
for (uint32_t k = 0; k < min(K * N, max_lds_len);
|
|
|
|
|
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
|
|
|
|
|
k += THRDS * WvPrGrp * A_CHUNK) {
|
|
|
|
|
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
|
|
|
|
|
|
|
|
|
|
if (k_in >= min(K * N, max_lds_len)) break;
|
|
|
|
|
if (k_in >= min__(K * N, max_lds_len)) break;
|
|
|
|
|
|
|
|
|
|
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
|
|
|
|
|
}
|
|
|
|
|
@@ -975,7 +980,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|
|
|
|
? kFit
|
|
|
|
|
: (kFit - kFit % TUC); // round up to multiple of TUC
|
|
|
|
|
// if (kFit == 0) kFit = TUC;
|
|
|
|
|
kFit = min(kFit, K);
|
|
|
|
|
kFit = min__(kFit, K);
|
|
|
|
|
|
|
|
|
|
float sum[N][YTILE];
|
|
|
|
|
scalar8 sum4[N][YTILE];
|
|
|
|
|
@@ -1251,6 +1256,7 @@ int mindiv(int N, int div1, int div2) {
|
|
|
|
|
}
|
|
|
|
|
for (int i = 12; i >= 0; i--)
|
|
|
|
|
if (rnds[0] == rnds[i]) return (div2 - i);
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
|
|
|
|
@@ -1352,6 +1358,536 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
|
|
|
|
return out_c;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if defined(__gfx950__) // TODO: Add NAVI support
|
|
|
|
|
// This version targets big A[] cases, where it is much larger than LDS
|
|
|
|
|
// capacity
|
|
|
|
|
#define WVSPLITKRC_1KPASS
|
|
|
|
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
|
|
|
|
int UNRL, int N, int GrpsShrB>
|
|
|
|
|
|
|
|
|
|
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|
|
|
|
__attribute__((amdgpu_waves_per_eu(1, 1)))
|
|
|
|
|
wvSplitKrc_(const int actlN, const int K, const int M, const int Bx,
|
|
|
|
|
const int By, const scalar_t* __restrict__ B,
|
|
|
|
|
const scalar_t* __restrict__ A,
|
|
|
|
|
const scalar_t* __restrict__ BIAS, float* glbl, scalar_t* C,
|
|
|
|
|
const int CuCount) {
|
|
|
|
|
// Use upper half of glbl buffer for atomic reduce counting
|
|
|
|
|
int* cntr = (int*)(&glbl[M * N]);
|
|
|
|
|
|
|
|
|
|
constexpr int NTILE = 16;
|
|
|
|
|
constexpr int WVLDS_ = (NTILE * THRDS * A_CHUNK);
|
|
|
|
|
constexpr int APAD = 1;
|
|
|
|
|
constexpr int ASTRD = 64;
|
|
|
|
|
constexpr int BPAD = 1;
|
|
|
|
|
constexpr int BSTRD = 64;
|
|
|
|
|
constexpr int WVLDS = ((WVLDS_ + (WVLDS_ / BSTRD) * 4 * BPAD));
|
|
|
|
|
|
|
|
|
|
constexpr int max_lds_len = LDS_SIZE / 2;
|
|
|
|
|
|
|
|
|
|
using scalar16 =
|
|
|
|
|
__attribute__((__vector_size__((A_CHUNK * 2) * sizeof(float)))) float;
|
|
|
|
|
using scalar8 =
|
|
|
|
|
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
|
|
|
|
|
using half4 =
|
|
|
|
|
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
|
|
|
|
|
union bigType {
|
|
|
|
|
scalar_t h[A_CHUNK];
|
|
|
|
|
float f[A_CHUNK / 2];
|
|
|
|
|
unsigned int i[A_CHUNK / 2];
|
|
|
|
|
float2 f2[A_CHUNK / 4];
|
|
|
|
|
unsigned long l[A_CHUNK / 4];
|
|
|
|
|
double d[A_CHUNK / 4];
|
|
|
|
|
half4 h4[A_CHUNK / 4];
|
|
|
|
|
scalar8 h8;
|
|
|
|
|
};
|
|
|
|
|
using big4 = __attribute__((__vector_size__(4 * sizeof(bigType)))) __bf16;
|
|
|
|
|
|
|
|
|
|
__shared__ scalar_t stg[WvPrGrp * WVLDS / GrpsShrB];
|
|
|
|
|
unsigned int* myStg = (unsigned int*)(&stg[WVLDS * (threadIdx.y / GrpsShrB)]);
|
|
|
|
|
__shared__ scalar_t s[max_lds_len - WvPrGrp * WVLDS / GrpsShrB];
|
|
|
|
|
|
|
|
|
|
#ifndef WVSPLITKRC_1KPASS
|
|
|
|
|
constexpr int TUC_ = (THRDS * UNRL * A_CHUNK);
|
|
|
|
|
// find biggest k size that fits padded into LDS
|
|
|
|
|
constexpr uint32_t kFit__ = (max_lds_len - WvPrGrp * WVLDS / GrpsShrB) / N;
|
|
|
|
|
constexpr uint32_t kFit_ = (kFit__ * ASTRD) / (APAD + ASTRD);
|
|
|
|
|
uint32_t kFit = kFit_ - (kFit_ % TUC_);
|
|
|
|
|
uint32_t kfitsPerRdc = (K + kFit - 1) / kFit;
|
|
|
|
|
|
|
|
|
|
// find best k split to fill the CUs
|
|
|
|
|
if (((K + kfitsPerRdc * kFit - 1) / (kfitsPerRdc * kFit)) * numCuWithFullK <=
|
|
|
|
|
CuCount)
|
|
|
|
|
while (true) {
|
|
|
|
|
while (kFit > TUC_) {
|
|
|
|
|
uint32_t kFit_ = kFit - TUC_;
|
|
|
|
|
if (((K + (kfitsPerRdc * kFit_ - 1)) / (kfitsPerRdc * kFit_)) *
|
|
|
|
|
numCuWithFullK >
|
|
|
|
|
CuCount)
|
|
|
|
|
break;
|
|
|
|
|
kFit = kFit_;
|
|
|
|
|
}
|
|
|
|
|
if (((K + ((kfitsPerRdc - 1) * kFit - 1)) / ((kfitsPerRdc - 1) * kFit)) *
|
|
|
|
|
numCuWithFullK <=
|
|
|
|
|
CuCount)
|
|
|
|
|
kfitsPerRdc--;
|
|
|
|
|
else
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
int constexpr kFit = 512;
|
|
|
|
|
int constexpr kfitsPerRdc = 1;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
bool doRdc = (kfitsPerRdc * kFit < K);
|
|
|
|
|
uint32_t numCuWithFullK =
|
|
|
|
|
((M + (WvPrGrp * YTILE / GrpsShrB) - 1) / (WvPrGrp * YTILE / GrpsShrB));
|
|
|
|
|
uint32_t Mmod = numCuWithFullK * (WvPrGrp * YTILE / GrpsShrB);
|
|
|
|
|
|
|
|
|
|
// given above k-split, find this wave's position
|
|
|
|
|
uint32_t kFitPdd = kFit + (kFit / ASTRD) * APAD;
|
|
|
|
|
uint32_t m0 = (blockIdx.x * WvPrGrp / GrpsShrB) * YTILE;
|
|
|
|
|
uint32_t m1 = ((threadIdx.y % WvPrGrp) / GrpsShrB) * YTILE;
|
|
|
|
|
uint32_t m = (m0 + m1) % Mmod;
|
|
|
|
|
const uint32_t k_str = (m0 / Mmod) * kFit * kfitsPerRdc;
|
|
|
|
|
uint32_t k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc;
|
|
|
|
|
const uint32_t k_rnd = (K + kFit * kfitsPerRdc - 1) / (kFit * kfitsPerRdc);
|
|
|
|
|
|
|
|
|
|
scalar8 sum4[N / NTILE / GrpsShrB][1];
|
|
|
|
|
bigType bigB_[YTILE / GrpsShrB][UNRL];
|
|
|
|
|
const uint32_t bLoader = (threadIdx.y % GrpsShrB);
|
|
|
|
|
uint32_t kBase = 0;
|
|
|
|
|
if (k_str >= K) return;
|
|
|
|
|
if (m >= Mmod) return;
|
|
|
|
|
|
|
|
|
|
bool noreloada = false;
|
|
|
|
|
constexpr bool FAST_UNSAFE_RDC_INIT = false;
|
|
|
|
|
|
|
|
|
|
#ifdef WVSPLITKRC_1KPASS
|
|
|
|
|
// Early glbl init, B[] loading, if 1KPASS
|
|
|
|
|
if constexpr (FAST_UNSAFE_RDC_INIT) {
|
|
|
|
|
if (m + (threadIdx.x % 16) < M)
|
|
|
|
|
if (doRdc)
|
|
|
|
|
if (k_str == 0) {
|
|
|
|
|
int mindx = m + (threadIdx.x % 16);
|
|
|
|
|
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
int adr_ = mindx + M * nindx_ / 4;
|
|
|
|
|
__hip_atomic_store(&cntr[adr_], 0, __ATOMIC_RELAXED,
|
|
|
|
|
__HIP_MEMORY_SCOPE_AGENT);
|
|
|
|
|
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
|
|
|
|
for (uint32_t j = 0; j < 4; j++) {
|
|
|
|
|
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
int adr = mindx + M * nindx;
|
|
|
|
|
__hip_atomic_store(&glbl[adr], 0, __ATOMIC_RELAXED,
|
|
|
|
|
__HIP_MEMORY_SCOPE_AGENT);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Load first B[] chunk
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
|
|
|
|
uint32_t k = k_str + k2 * THRDS * A_CHUNK;
|
|
|
|
|
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
|
|
|
|
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (uint32_t y = 0; y < YTILE / GrpsShrB; y++)
|
|
|
|
|
bigB_[y][k2].h8 = (loadnt(
|
|
|
|
|
(scalar8*)(&B_[min__(y * GrpsShrB + bLoader + m, M - 1) * K])));
|
|
|
|
|
}
|
|
|
|
|
{
|
|
|
|
|
#else
|
|
|
|
|
while (m < Mmod) {
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#ifndef WVSPLITKRC_1KPASS
|
|
|
|
|
if constexpr (FAST_UNSAFE_RDC_INIT) {
|
|
|
|
|
if (m + (threadIdx.x % 16) < M)
|
|
|
|
|
if (doRdc)
|
|
|
|
|
if (k_str == 0) {
|
|
|
|
|
int mindx = m + (threadIdx.x % 16);
|
|
|
|
|
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
int adr_ = mindx + M * nindx_ / 4;
|
|
|
|
|
__hip_atomic_store(&cntr[adr_], 0, __ATOMIC_RELAXED,
|
|
|
|
|
__HIP_MEMORY_SCOPE_AGENT);
|
|
|
|
|
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
|
|
|
|
for (uint32_t j = 0; j < 4; j++) {
|
|
|
|
|
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
int adr = mindx + M * nindx;
|
|
|
|
|
__hip_atomic_store(&glbl[adr], 0, __ATOMIC_RELAXED,
|
|
|
|
|
__HIP_MEMORY_SCOPE_AGENT);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#ifndef WVSPLITKRC_1KPASS
|
|
|
|
|
for (uint32_t k1 = k_str; k1 < k_end; k1 += THRDS * A_CHUNK * UNRL) {
|
|
|
|
|
#else
|
|
|
|
|
const uint32_t k1 = k_str;
|
|
|
|
|
{
|
|
|
|
|
#endif
|
|
|
|
|
#ifndef WVSPLITKRC_1KPASS
|
|
|
|
|
const bool reloada = (!noreloada) &&
|
|
|
|
|
((k1 == k_str) || (k1 == k_str + kBase + kFit)) &&
|
|
|
|
|
(k1 < k_end);
|
|
|
|
|
// load next chunk of A[] to LDS
|
|
|
|
|
if (reloada) {
|
|
|
|
|
if (k1 != k_str) kBase += kFit;
|
|
|
|
|
__syncthreads();
|
|
|
|
|
#else
|
|
|
|
|
const bool reloada = (!noreloada) &&
|
|
|
|
|
((k1 == k_str) || (k1 == k_str + kBase + kFit)) &&
|
|
|
|
|
(k1 < k_end);
|
|
|
|
|
if (reloada) {
|
|
|
|
|
#endif
|
|
|
|
|
constexpr int sprdN = 4;
|
|
|
|
|
const uint32_t thrd = ((threadIdx.y / sprdN) * THRDS + threadIdx.x);
|
|
|
|
|
|
|
|
|
|
#ifndef WVSPLITKRC_1KPASS
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int k = 0; k < kFit; k += THRDS * (WvPrGrp / sprdN) * A_CHUNK) {
|
|
|
|
|
#else
|
|
|
|
|
const unsigned int k = 0;
|
|
|
|
|
{
|
|
|
|
|
#endif
|
|
|
|
|
unsigned int kOff = k + (thrd * A_CHUNK);
|
|
|
|
|
unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff);
|
|
|
|
|
const unsigned int k_in = kOffcp + ((threadIdx.y % sprdN)) * K;
|
|
|
|
|
const unsigned int k_ot = kOff + ((threadIdx.y % sprdN)) * kFitPdd;
|
|
|
|
|
for (unsigned int n = 0; n < N / 2; n += sprdN) {
|
|
|
|
|
__builtin_amdgcn_global_load_lds((int*)(&A[k_in + n * K]),
|
|
|
|
|
(int*)(&s[(k_ot + n * kFitPdd)]),
|
|
|
|
|
16, 0, 0);
|
|
|
|
|
if (((threadIdx.y % sprdN)) + n + N / 2 >= actlN) continue;
|
|
|
|
|
__builtin_amdgcn_global_load_lds(
|
|
|
|
|
(int*)(&A[k_in + (n + N / 2) * K]),
|
|
|
|
|
(int*)(&s[(k_ot + (n + N / 2) * kFitPdd)]), 16, 0, 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Stage loaded B[] to LDS for MFMA swizzling...
|
|
|
|
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
|
|
|
|
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
|
|
|
|
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
|
|
|
|
const bool oob_k = (k_ >= K);
|
|
|
|
|
for (uint32_t y = 0; y < YTILE / GrpsShrB; y++) {
|
|
|
|
|
uint32_t idx = threadIdx.x * 4 +
|
|
|
|
|
(y * GrpsShrB + bLoader) * ((THRDS + BPAD) * 4);
|
|
|
|
|
// zero out if oob
|
|
|
|
|
*((scalar8*)&myStg[idx]) =
|
|
|
|
|
(oob_k || (y * GrpsShrB + bLoader + m >= M))
|
|
|
|
|
? 0
|
|
|
|
|
: bigB_[y][k2].h8;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifndef WVSPLITKRC_1KPASS
|
|
|
|
|
// Fire load of next B[] chunk...
|
|
|
|
|
if ((k1 + THRDS * A_CHUNK * UNRL < k_end) &&
|
|
|
|
|
(k1 + THRDS * A_CHUNK * UNRL < K))
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
|
|
|
|
uint32_t k = k1 + THRDS * A_CHUNK * UNRL + k2 * THRDS * A_CHUNK;
|
|
|
|
|
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
|
|
|
|
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (uint32_t y = 0; y < YTILE / GrpsShrB; y++)
|
|
|
|
|
bigB_[y][k2].h8 = (loadnt(
|
|
|
|
|
(scalar8*)(&B_[min__(y * GrpsShrB + bLoader + m, M - 1) * K])));
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// B[] staging is cooperative across GrpsShrB, so sync here before reading
|
|
|
|
|
// back
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
// read back B[] swizzled for MFMA...
|
|
|
|
|
bigType bigB[YTILE][UNRL];
|
|
|
|
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
|
|
|
|
for (uint32_t y = 0; y < YTILE; y++) {
|
|
|
|
|
unsigned int idx = (threadIdx.x % YTILE) * ((THRDS + BPAD) * 4) +
|
|
|
|
|
(threadIdx.x / YTILE) * 4 + y * 16;
|
|
|
|
|
bigB[y][k2].h8 = *((scalar8*)&myStg[idx]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// rReadback A[] swizzled for MFMA...
|
|
|
|
|
bigType bigA[N / GrpsShrB][UNRL];
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
|
|
|
|
uint32_t k = k1 + k2 * THRDS * A_CHUNK - kBase - k_str;
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (uint32_t nt = 0; nt < N / GrpsShrB; nt += NTILE)
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (uint32_t n = 0; n < NTILE; n++) {
|
|
|
|
|
uint32_t idxa = (nt + (threadIdx.x % NTILE) +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB)) *
|
|
|
|
|
kFitPdd +
|
|
|
|
|
A_CHUNK * ((threadIdx.x / NTILE) + n * 4) + k;
|
|
|
|
|
bigA[nt + n][k2] = *((const bigType*)(&(s[idxa])));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Do the MFMAs
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
|
|
|
|
if constexpr (std::is_same_v<scalar_t, half>) {
|
|
|
|
|
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
|
|
|
|
bigA[nt * NTILE + 0][k2].h4[0], bigB[0][k2].h4[0],
|
|
|
|
|
(k1 == k_str) ? ((scalar8){0}) : sum4[nt][0], 0, 0, 0);
|
|
|
|
|
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
|
|
|
|
bigA[nt * NTILE + 0][k2].h4[1], bigB[0][k2].h4[1], sum4[nt][0], 0,
|
|
|
|
|
0, 0);
|
|
|
|
|
} else { // bf16
|
|
|
|
|
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
|
|
|
|
bigA[nt * NTILE + 0][k2].h4[0], bigB[0][k2].h4[0],
|
|
|
|
|
(k1 == k_str) ? ((scalar8){0}) : sum4[nt][0], 0, 0, 0);
|
|
|
|
|
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
|
|
|
|
bigA[nt * NTILE + 0][k2].h4[1], bigB[0][k2].h4[1], sum4[nt][0], 0,
|
|
|
|
|
0, 0);
|
|
|
|
|
}
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (uint32_t j = 1; j < YTILE; j++) {
|
|
|
|
|
if constexpr (std::is_same_v<scalar_t, half>) {
|
|
|
|
|
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
|
|
|
|
bigA[nt * NTILE + j][k2].h4[0], bigB[j][k2].h4[0], sum4[nt][0],
|
|
|
|
|
0, 0, 0);
|
|
|
|
|
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
|
|
|
|
bigA[nt * NTILE + j][k2].h4[1], bigB[j][k2].h4[1], sum4[nt][0],
|
|
|
|
|
0, 0, 0);
|
|
|
|
|
} else { // bf16
|
|
|
|
|
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
|
|
|
|
bigA[nt * NTILE + j][k2].h4[0], bigB[j][k2].h4[0], sum4[nt][0],
|
|
|
|
|
0, 0, 0);
|
|
|
|
|
sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
|
|
|
|
bigA[nt * NTILE + j][k2].h4[1], bigB[j][k2].h4[1], sum4[nt][0],
|
|
|
|
|
0, 0, 0);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!doRdc) {
|
|
|
|
|
if (m + (threadIdx.x % 16) < M) {
|
|
|
|
|
scalar_t biases[N / NTILE / GrpsShrB][4] = {0};
|
|
|
|
|
if (BIAS)
|
|
|
|
|
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
|
|
|
|
for (uint32_t j = 0; j < 4; j++) {
|
|
|
|
|
int mindx = m + (threadIdx.x % 16);
|
|
|
|
|
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * M];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
|
|
|
|
for (uint32_t j = 0; j < 4; j++) {
|
|
|
|
|
int mindx = m + (threadIdx.x % 16);
|
|
|
|
|
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
int adr = mindx + M * nindx;
|
|
|
|
|
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
|
|
|
|
if (BIAS) sum4[nt][0][j] += __bfloat162float(biases[nt][j]);
|
|
|
|
|
C[adr] = __float2bfloat16(sum4[nt][0][j]);
|
|
|
|
|
} else {
|
|
|
|
|
if (BIAS) sum4[nt][0][j] += __half2float(biases[nt][j]);
|
|
|
|
|
C[adr] = __float2half(sum4[nt][0][j]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (m + (threadIdx.x % 16) < M) {
|
|
|
|
|
int my_cntr;
|
|
|
|
|
if (!BIAS) {
|
|
|
|
|
int mindx = m + (threadIdx.x % 16);
|
|
|
|
|
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
|
|
|
|
|
for (uint32_t j = 0; j < 4; j++) {
|
|
|
|
|
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
int adr = mindx + M * nindx;
|
|
|
|
|
atomicAdd(&glbl[adr], sum4[nt][0][j]);
|
|
|
|
|
}
|
|
|
|
|
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
int adr_ = mindx + M * nindx_ / 4;
|
|
|
|
|
my_cntr = atomicAdd(&cntr[adr_], 1);
|
|
|
|
|
float vals[N / NTILE / GrpsShrB][4] = {};
|
|
|
|
|
if (my_cntr + 1 == k_rnd) {
|
|
|
|
|
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
|
|
|
|
for (uint32_t j = 0; j < 4; j++) {
|
|
|
|
|
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
int adr = mindx + M * nindx;
|
|
|
|
|
vals[nt][j] = glbl[adr];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
|
|
|
|
for (uint32_t j = 0; j < 4; j++) {
|
|
|
|
|
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
if (nindx >= actlN) break;
|
|
|
|
|
int adr = mindx + M * nindx;
|
|
|
|
|
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
|
|
|
|
C[adr] = __float2bfloat16(vals[nt][j]);
|
|
|
|
|
} else {
|
|
|
|
|
C[adr] = __float2half(vals[nt][j]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
int mindx = m + (threadIdx.x % 16);
|
|
|
|
|
scalar_t biases[N / NTILE / GrpsShrB][4] = {};
|
|
|
|
|
// Atomic add the output, read biases
|
|
|
|
|
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++)
|
|
|
|
|
for (uint32_t j = 0; j < 4; j++) {
|
|
|
|
|
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
int adr = mindx + M * nindx;
|
|
|
|
|
atomicAdd(&glbl[adr], sum4[nt][0][j]);
|
|
|
|
|
biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * M];
|
|
|
|
|
}
|
|
|
|
|
int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
int adr_ = mindx + M * nindx_ / 4;
|
|
|
|
|
// Update the complete counter
|
|
|
|
|
my_cntr = atomicAdd(&cntr[adr_], 1);
|
|
|
|
|
float vals[N / NTILE / GrpsShrB][4] = {};
|
|
|
|
|
// If we're the last k-shard, read back the value and convert...
|
|
|
|
|
if (my_cntr + 1 == k_rnd) {
|
|
|
|
|
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
|
|
|
|
for (uint32_t j = 0; j < 4; j++) {
|
|
|
|
|
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
int adr = mindx + M * nindx;
|
|
|
|
|
vals[nt][j] = glbl[adr];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) {
|
|
|
|
|
for (uint32_t j = 0; j < 4; j++) {
|
|
|
|
|
int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE +
|
|
|
|
|
(N / GrpsShrB) * (threadIdx.y % GrpsShrB);
|
|
|
|
|
if (nindx >= actlN) break;
|
|
|
|
|
int adr = mindx + M * nindx;
|
|
|
|
|
if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
|
|
|
|
vals[nt][j] += __bfloat162float(biases[nt][j]);
|
|
|
|
|
C[adr] = __float2bfloat16(vals[nt][j]);
|
|
|
|
|
} else {
|
|
|
|
|
vals[nt][j] += __half2float(biases[nt][j]);
|
|
|
|
|
C[adr] = __float2half(vals[nt][j]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifndef WVSPLITKRC_1KPASS
|
|
|
|
|
m0 += CuCount * WvPrGrp * YTILE / GrpsShrB;
|
|
|
|
|
m = (m0 + m1) % Mmod;
|
|
|
|
|
k_str = (m0 / Mmod) * kFit * kfitsPerRdc;
|
|
|
|
|
k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc;
|
|
|
|
|
if (k_str >= K) break;
|
|
|
|
|
kBase = 0;
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
|
|
|
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
|
|
|
|
int UNRL, int N, int GrpsShrB>
|
|
|
|
|
__global__ void wvSplitKrc_(const int actlN, const int K, const int M,
|
|
|
|
|
const int Bx, const int By, const scalar_t* B,
|
|
|
|
|
const scalar_t* __restrict__ A,
|
|
|
|
|
const scalar_t* __restrict__ BIAS, float* glbl,
|
|
|
|
|
// int* cntr,
|
|
|
|
|
scalar_t* C, const int CuCount){UNREACHABLE_CODE}
|
|
|
|
|
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
|
|
|
|
|
|
|
|
|
torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
|
|
|
|
const std::optional<at::Tensor>& in_bias,
|
|
|
|
|
const int64_t CuCount) {
|
|
|
|
|
auto M_in = in_a.size(0);
|
|
|
|
|
auto N_in = in_b.size(0);
|
|
|
|
|
auto K_in = in_a.size(1);
|
|
|
|
|
auto Bx_in =
|
|
|
|
|
(in_bias.has_value() && in_bias->numel() > 0)
|
|
|
|
|
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
|
|
|
|
|
: 1;
|
|
|
|
|
auto By_in = (in_bias.has_value() && in_bias->numel() > 0 &&
|
|
|
|
|
in_bias->sizes().size() == 2)
|
|
|
|
|
? in_bias->size(0)
|
|
|
|
|
: 1;
|
|
|
|
|
|
|
|
|
|
TORCH_CHECK(in_a.dtype() == in_b.dtype());
|
|
|
|
|
TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0");
|
|
|
|
|
TORCH_CHECK(in_a.dtype() == torch::kFloat16 ||
|
|
|
|
|
in_a.dtype() == torch::kBFloat16);
|
|
|
|
|
|
|
|
|
|
auto out_c = torch::empty(
|
|
|
|
|
{N_in, M_in},
|
|
|
|
|
torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device()));
|
|
|
|
|
|
|
|
|
|
auto N_p2 = 1U << (32 - __builtin_clz(N_in - 1));
|
|
|
|
|
auto axl_glbl = torch::empty(
|
|
|
|
|
{N_p2 + N_p2 / 4, M_in + M_in / 4},
|
|
|
|
|
torch::TensorOptions().dtype(torch::kFloat32).device(in_b.device()));
|
|
|
|
|
axl_glbl.zero_(); // disable for FAST_UNSAFE_RDC_INIT
|
|
|
|
|
|
|
|
|
|
dim3 grid(CuCount);
|
|
|
|
|
|
|
|
|
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
|
|
|
|
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
|
|
|
// const int max_lds_len = get_lds_size() / 2;
|
|
|
|
|
|
|
|
|
|
#define WVSPLITKrc(_WvPrGrp, _YTILE, _UNRL, _N, _GrpsShrB) \
|
|
|
|
|
{ \
|
|
|
|
|
dim3 block(64, _WvPrGrp); \
|
|
|
|
|
wvSplitKrc_<fptype, 64, _YTILE, _WvPrGrp, 8, _UNRL, _N, _GrpsShrB> \
|
|
|
|
|
<<<grid, block, 0, stream>>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \
|
|
|
|
|
biasf4, glbl, c, CuCount); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitKrc", [&] {
|
|
|
|
|
using fptype = typename scalar<scalar_t>::type;
|
|
|
|
|
fptype* af4 = reinterpret_cast<fptype*>(in_a.data_ptr());
|
|
|
|
|
const fptype* bf4 = reinterpret_cast<const fptype*>(in_b.data_ptr());
|
|
|
|
|
const fptype* biasf4 =
|
|
|
|
|
(in_bias.has_value() && in_bias->numel() > 0)
|
|
|
|
|
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
|
|
|
|
|
: nullptr;
|
|
|
|
|
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
|
|
|
|
|
auto glbl = axl_glbl.data_ptr<float>();
|
|
|
|
|
switch (N_p2) {
|
|
|
|
|
case 16:
|
|
|
|
|
WVSPLITKrc(4, 16, 1, 16, 1) break;
|
|
|
|
|
case 32:
|
|
|
|
|
WVSPLITKrc(4, 16, 1, 32, 2) break;
|
|
|
|
|
case 64:
|
|
|
|
|
WVSPLITKrc(4, 16, 1, 64, 2) break;
|
|
|
|
|
case 128:
|
|
|
|
|
WVSPLITKrc(4, 16, 1, 128, 4) break;
|
|
|
|
|
default:
|
|
|
|
|
throw std::runtime_error(
|
|
|
|
|
"Unsupported N value: " + std::to_string(M_in) + "," +
|
|
|
|
|
std::to_string(K_in) + "," + std::to_string(N_in));
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
return out_c;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
|
|
|
|
|
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
|
|
|
|
int A_CHUNK, int UNRL, int N>
|
|
|
|
|
@@ -1381,7 +1917,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|
|
|
|
__shared__ fp8_t s[max_lds_len];
|
|
|
|
|
|
|
|
|
|
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
|
|
|
|
|
k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
|
|
|
|
k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
|
|
|
|
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
@@ -1570,7 +2106,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|
|
|
|
__shared__ fp8_t s[max_lds_len];
|
|
|
|
|
|
|
|
|
|
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
|
|
|
|
|
k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
|
|
|
|
k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
|
|
|
|
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|