diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 8b8036258..dbc466f03 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -9,6 +9,10 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, const std::optional& in_bias, const int64_t CuCount); +torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, + const std::optional& in_bias, + const int64_t CuCount); + void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, const std::optional& in_bias, at::Tensor& out_c, const at::Tensor& scale_a, const at::Tensor& scale_b, diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 8ebe55cef..50b6f6315 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -287,6 +287,11 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, V0 += (s.x + s.y); \ } +// To avoid LLVM silently upcasting to double +__device__ inline unsigned int min__(uint32_t a, uint32_t b) { + return min(a, b); +} + #if defined(__HIP__GFX9__) // TODO: Add NAVI support // This version targets cases where A[] fits LDS capacity template = min(K * N, max_lds_len)) break; + if (k_in >= min__(K * N, max_lds_len)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } @@ -633,11 +638,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Then the WG will move to another 8 K elements // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * N, max_lds_len); + for (uint32_t k = 0; k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K * N, max_lds_len)) break; + if (k_in >= min__(K * N, max_lds_len)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } @@ -954,11 +959,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- #define PCML #ifndef PCML - for (uint32_t k = 0; k < min(K * N, max_lds_len); + for (uint32_t k = 0; k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (k_in >= min(K * N, max_lds_len)) break; + if (k_in >= min__(K * N, max_lds_len)) break; *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); } @@ -975,7 +980,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ? kFit : (kFit - kFit % TUC); // round up to multiple of TUC // if (kFit == 0) kFit = TUC; - kFit = min(kFit, K); + kFit = min__(kFit, K); float sum[N][YTILE]; scalar8 sum4[N][YTILE]; @@ -1251,6 +1256,7 @@ int mindiv(int N, int div1, int div2) { } for (int i = 12; i >= 0; i--) if (rnds[0] == rnds[i]) return (div2 - i); + return 0; } torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, @@ -1352,6 +1358,536 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, return out_c; } +#if defined(__gfx950__) // TODO: Add NAVI support + // This version targets big A[] cases, where it is much larger than LDS + // capacity + #define WVSPLITKRC_1KPASS +template + +__global__ void __launch_bounds__(WvPrGrp* THRDS) + __attribute__((amdgpu_waves_per_eu(1, 1))) + wvSplitKrc_(const int actlN, const int K, const int M, const int Bx, + const int By, const scalar_t* __restrict__ B, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, float* glbl, scalar_t* C, + const int CuCount) { + // Use upper half of glbl buffer for atomic reduce counting + int* cntr = (int*)(&glbl[M * N]); + + constexpr int NTILE = 16; + constexpr int WVLDS_ = (NTILE * THRDS * A_CHUNK); + constexpr int APAD = 1; + constexpr int ASTRD = 64; + constexpr int BPAD = 1; + constexpr int BSTRD = 64; + constexpr int WVLDS = ((WVLDS_ + (WVLDS_ / BSTRD) * 4 * BPAD)); + + constexpr int max_lds_len = LDS_SIZE / 2; + + using scalar16 = + __attribute__((__vector_size__((A_CHUNK * 2) * sizeof(float)))) float; + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; + union bigType { + scalar_t h[A_CHUNK]; + float f[A_CHUNK / 2]; + unsigned int i[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + unsigned long l[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; + scalar8 h8; + }; + using big4 = __attribute__((__vector_size__(4 * sizeof(bigType)))) __bf16; + + __shared__ scalar_t stg[WvPrGrp * WVLDS / GrpsShrB]; + unsigned int* myStg = (unsigned int*)(&stg[WVLDS * (threadIdx.y / GrpsShrB)]); + __shared__ scalar_t s[max_lds_len - WvPrGrp * WVLDS / GrpsShrB]; + + #ifndef WVSPLITKRC_1KPASS + constexpr int TUC_ = (THRDS * UNRL * A_CHUNK); + // find biggest k size that fits padded into LDS + constexpr uint32_t kFit__ = (max_lds_len - WvPrGrp * WVLDS / GrpsShrB) / N; + constexpr uint32_t kFit_ = (kFit__ * ASTRD) / (APAD + ASTRD); + uint32_t kFit = kFit_ - (kFit_ % TUC_); + uint32_t kfitsPerRdc = (K + kFit - 1) / kFit; + + // find best k split to fill the CUs + if (((K + kfitsPerRdc * kFit - 1) / (kfitsPerRdc * kFit)) * numCuWithFullK <= + CuCount) + while (true) { + while (kFit > TUC_) { + uint32_t kFit_ = kFit - TUC_; + if (((K + (kfitsPerRdc * kFit_ - 1)) / (kfitsPerRdc * kFit_)) * + numCuWithFullK > + CuCount) + break; + kFit = kFit_; + } + if (((K + ((kfitsPerRdc - 1) * kFit - 1)) / ((kfitsPerRdc - 1) * kFit)) * + numCuWithFullK <= + CuCount) + kfitsPerRdc--; + else + break; + } + #else + int constexpr kFit = 512; + int constexpr kfitsPerRdc = 1; + #endif + + bool doRdc = (kfitsPerRdc * kFit < K); + uint32_t numCuWithFullK = + ((M + (WvPrGrp * YTILE / GrpsShrB) - 1) / (WvPrGrp * YTILE / GrpsShrB)); + uint32_t Mmod = numCuWithFullK * (WvPrGrp * YTILE / GrpsShrB); + + // given above k-split, find this wave's position + uint32_t kFitPdd = kFit + (kFit / ASTRD) * APAD; + uint32_t m0 = (blockIdx.x * WvPrGrp / GrpsShrB) * YTILE; + uint32_t m1 = ((threadIdx.y % WvPrGrp) / GrpsShrB) * YTILE; + uint32_t m = (m0 + m1) % Mmod; + const uint32_t k_str = (m0 / Mmod) * kFit * kfitsPerRdc; + uint32_t k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc; + const uint32_t k_rnd = (K + kFit * kfitsPerRdc - 1) / (kFit * kfitsPerRdc); + + scalar8 sum4[N / NTILE / GrpsShrB][1]; + bigType bigB_[YTILE / GrpsShrB][UNRL]; + const uint32_t bLoader = (threadIdx.y % GrpsShrB); + uint32_t kBase = 0; + if (k_str >= K) return; + if (m >= Mmod) return; + + bool noreloada = false; + constexpr bool FAST_UNSAFE_RDC_INIT = false; + + #ifdef WVSPLITKRC_1KPASS + // Early glbl init, B[] loading, if 1KPASS + if constexpr (FAST_UNSAFE_RDC_INIT) { + if (m + (threadIdx.x % 16) < M) + if (doRdc) + if (k_str == 0) { + int mindx = m + (threadIdx.x % 16); + int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr_ = mindx + M * nindx_ / 4; + __hip_atomic_store(&cntr[adr_], 0, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr = mindx + M * nindx; + __hip_atomic_store(&glbl[adr], 0, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + } + } + } + + // Load first B[] chunk + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k_str + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)]; + #pragma unroll + for (uint32_t y = 0; y < YTILE / GrpsShrB; y++) + bigB_[y][k2].h8 = (loadnt( + (scalar8*)(&B_[min__(y * GrpsShrB + bLoader + m, M - 1) * K]))); + } + { + #else + while (m < Mmod) { + #endif + + #ifndef WVSPLITKRC_1KPASS + if constexpr (FAST_UNSAFE_RDC_INIT) { + if (m + (threadIdx.x % 16) < M) + if (doRdc) + if (k_str == 0) { + int mindx = m + (threadIdx.x % 16); + int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr_ = mindx + M * nindx_ / 4; + __hip_atomic_store(&cntr[adr_], 0, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr = mindx + M * nindx; + __hip_atomic_store(&glbl[adr], 0, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + } + } + } + + #endif + + #ifndef WVSPLITKRC_1KPASS + for (uint32_t k1 = k_str; k1 < k_end; k1 += THRDS * A_CHUNK * UNRL) { + #else + const uint32_t k1 = k_str; + { + #endif + #ifndef WVSPLITKRC_1KPASS + const bool reloada = (!noreloada) && + ((k1 == k_str) || (k1 == k_str + kBase + kFit)) && + (k1 < k_end); + // load next chunk of A[] to LDS + if (reloada) { + if (k1 != k_str) kBase += kFit; + __syncthreads(); + #else + const bool reloada = (!noreloada) && + ((k1 == k_str) || (k1 == k_str + kBase + kFit)) && + (k1 < k_end); + if (reloada) { + #endif + constexpr int sprdN = 4; + const uint32_t thrd = ((threadIdx.y / sprdN) * THRDS + threadIdx.x); + + #ifndef WVSPLITKRC_1KPASS + #pragma unroll + for (int k = 0; k < kFit; k += THRDS * (WvPrGrp / sprdN) * A_CHUNK) { + #else + const unsigned int k = 0; + { + #endif + unsigned int kOff = k + (thrd * A_CHUNK); + unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff); + const unsigned int k_in = kOffcp + ((threadIdx.y % sprdN)) * K; + const unsigned int k_ot = kOff + ((threadIdx.y % sprdN)) * kFitPdd; + for (unsigned int n = 0; n < N / 2; n += sprdN) { + __builtin_amdgcn_global_load_lds((int*)(&A[k_in + n * K]), + (int*)(&s[(k_ot + n * kFitPdd)]), + 16, 0, 0); + if (((threadIdx.y % sprdN)) + n + N / 2 >= actlN) continue; + __builtin_amdgcn_global_load_lds( + (int*)(&A[k_in + (n + N / 2) * K]), + (int*)(&s[(k_ot + (n + N / 2) * kFitPdd)]), 16, 0, 0); + } + + // Stage loaded B[] to LDS for MFMA swizzling... + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + const bool oob_k = (k_ >= K); + for (uint32_t y = 0; y < YTILE / GrpsShrB; y++) { + uint32_t idx = threadIdx.x * 4 + + (y * GrpsShrB + bLoader) * ((THRDS + BPAD) * 4); + // zero out if oob + *((scalar8*)&myStg[idx]) = + (oob_k || (y * GrpsShrB + bLoader + m >= M)) + ? 0 + : bigB_[y][k2].h8; + } + } + } + } + } + + #ifndef WVSPLITKRC_1KPASS + // Fire load of next B[] chunk... + if ((k1 + THRDS * A_CHUNK * UNRL < k_end) && + (k1 + THRDS * A_CHUNK * UNRL < K)) + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + THRDS * A_CHUNK * UNRL + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)]; + #pragma unroll + for (uint32_t y = 0; y < YTILE / GrpsShrB; y++) + bigB_[y][k2].h8 = (loadnt( + (scalar8*)(&B_[min__(y * GrpsShrB + bLoader + m, M - 1) * K]))); + } + #endif + + // B[] staging is cooperative across GrpsShrB, so sync here before reading + // back + __syncthreads(); + + // read back B[] swizzled for MFMA... + bigType bigB[YTILE][UNRL]; + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + for (uint32_t y = 0; y < YTILE; y++) { + unsigned int idx = (threadIdx.x % YTILE) * ((THRDS + BPAD) * 4) + + (threadIdx.x / YTILE) * 4 + y * 16; + bigB[y][k2].h8 = *((scalar8*)&myStg[idx]); + } + } + + // rReadback A[] swizzled for MFMA... + bigType bigA[N / GrpsShrB][UNRL]; + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK - kBase - k_str; + #pragma unroll + for (uint32_t nt = 0; nt < N / GrpsShrB; nt += NTILE) + #pragma unroll + for (uint32_t n = 0; n < NTILE; n++) { + uint32_t idxa = (nt + (threadIdx.x % NTILE) + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) * + kFitPdd + + A_CHUNK * ((threadIdx.x / NTILE) + n * 4) + k; + bigA[nt + n][k2] = *((const bigType*)(&(s[idxa]))); + } + } + + // Do the MFMAs + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + #pragma unroll + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + if constexpr (std::is_same_v) { + sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16( + bigA[nt * NTILE + 0][k2].h4[0], bigB[0][k2].h4[0], + (k1 == k_str) ? ((scalar8){0}) : sum4[nt][0], 0, 0, 0); + sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16( + bigA[nt * NTILE + 0][k2].h4[1], bigB[0][k2].h4[1], sum4[nt][0], 0, + 0, 0); + } else { // bf16 + sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + bigA[nt * NTILE + 0][k2].h4[0], bigB[0][k2].h4[0], + (k1 == k_str) ? ((scalar8){0}) : sum4[nt][0], 0, 0, 0); + sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + bigA[nt * NTILE + 0][k2].h4[1], bigB[0][k2].h4[1], sum4[nt][0], 0, + 0, 0); + } + #pragma unroll + for (uint32_t j = 1; j < YTILE; j++) { + if constexpr (std::is_same_v) { + sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16( + bigA[nt * NTILE + j][k2].h4[0], bigB[j][k2].h4[0], sum4[nt][0], + 0, 0, 0); + sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16f16( + bigA[nt * NTILE + j][k2].h4[1], bigB[j][k2].h4[1], sum4[nt][0], + 0, 0, 0); + } else { // bf16 + sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + bigA[nt * NTILE + j][k2].h4[0], bigB[j][k2].h4[0], sum4[nt][0], + 0, 0, 0); + sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + bigA[nt * NTILE + j][k2].h4[1], bigB[j][k2].h4[1], sum4[nt][0], + 0, 0, 0); + } + } + } + } + } + + if (!doRdc) { + if (m + (threadIdx.x % 16) < M) { + scalar_t biases[N / NTILE / GrpsShrB][4] = {0}; + if (BIAS) + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int mindx = m + (threadIdx.x % 16); + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * M]; + } + } + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int mindx = m + (threadIdx.x % 16); + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr = mindx + M * nindx; + if constexpr (std::is_same_v) { + if (BIAS) sum4[nt][0][j] += __bfloat162float(biases[nt][j]); + C[adr] = __float2bfloat16(sum4[nt][0][j]); + } else { + if (BIAS) sum4[nt][0][j] += __half2float(biases[nt][j]); + C[adr] = __float2half(sum4[nt][0][j]); + } + } + } + } + } else { + if (m + (threadIdx.x % 16) < M) { + int my_cntr; + if (!BIAS) { + int mindx = m + (threadIdx.x % 16); + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr = mindx + M * nindx; + atomicAdd(&glbl[adr], sum4[nt][0][j]); + } + int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr_ = mindx + M * nindx_ / 4; + my_cntr = atomicAdd(&cntr[adr_], 1); + float vals[N / NTILE / GrpsShrB][4] = {}; + if (my_cntr + 1 == k_rnd) { + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr = mindx + M * nindx; + vals[nt][j] = glbl[adr]; + } + } + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + if (nindx >= actlN) break; + int adr = mindx + M * nindx; + if constexpr (std::is_same_v) { + C[adr] = __float2bfloat16(vals[nt][j]); + } else { + C[adr] = __float2half(vals[nt][j]); + } + } + } + } + } else { + int mindx = m + (threadIdx.x % 16); + scalar_t biases[N / NTILE / GrpsShrB][4] = {}; + // Atomic add the output, read biases + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr = mindx + M * nindx; + atomicAdd(&glbl[adr], sum4[nt][0][j]); + biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * M]; + } + int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr_ = mindx + M * nindx_ / 4; + // Update the complete counter + my_cntr = atomicAdd(&cntr[adr_], 1); + float vals[N / NTILE / GrpsShrB][4] = {}; + // If we're the last k-shard, read back the value and convert... + if (my_cntr + 1 == k_rnd) { + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr = mindx + M * nindx; + vals[nt][j] = glbl[adr]; + } + } + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + if (nindx >= actlN) break; + int adr = mindx + M * nindx; + if constexpr (std::is_same_v) { + vals[nt][j] += __bfloat162float(biases[nt][j]); + C[adr] = __float2bfloat16(vals[nt][j]); + } else { + vals[nt][j] += __half2float(biases[nt][j]); + C[adr] = __float2half(vals[nt][j]); + } + } + } + } + } + } + + #ifndef WVSPLITKRC_1KPASS + m0 += CuCount * WvPrGrp * YTILE / GrpsShrB; + m = (m0 + m1) % Mmod; + k_str = (m0 / Mmod) * kFit * kfitsPerRdc; + k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc; + if (k_str >= K) break; + kBase = 0; + #endif + } +} +#else // !defined(__HIP__GFX9__) TODO: Add NAVI support +template +__global__ void wvSplitKrc_(const int actlN, const int K, const int M, + const int Bx, const int By, const scalar_t* B, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, float* glbl, + // int* cntr, + scalar_t* C, const int CuCount){UNREACHABLE_CODE} +#endif // defined(__HIP__GFX9__) TODO: Add NAVI support + +torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, + const std::optional& in_bias, + const int64_t CuCount) { + auto M_in = in_a.size(0); + auto N_in = in_b.size(0); + auto K_in = in_a.size(1); + auto Bx_in = + (in_bias.has_value() && in_bias->numel() > 0) + ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) + : 1; + auto By_in = (in_bias.has_value() && in_bias->numel() > 0 && + in_bias->sizes().size() == 2) + ? in_bias->size(0) + : 1; + + TORCH_CHECK(in_a.dtype() == in_b.dtype()); + TORCH_CHECK(K_in % 8 == 0, "k % 8 == 0"); + TORCH_CHECK(in_a.dtype() == torch::kFloat16 || + in_a.dtype() == torch::kBFloat16); + + auto out_c = torch::empty( + {N_in, M_in}, + torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device())); + + auto N_p2 = 1U << (32 - __builtin_clz(N_in - 1)); + auto axl_glbl = torch::empty( + {N_p2 + N_p2 / 4, M_in + M_in / 4}, + torch::TensorOptions().dtype(torch::kFloat32).device(in_b.device())); + axl_glbl.zero_(); // disable for FAST_UNSAFE_RDC_INIT + + dim3 grid(CuCount); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // const int max_lds_len = get_lds_size() / 2; + +#define WVSPLITKrc(_WvPrGrp, _YTILE, _UNRL, _N, _GrpsShrB) \ + { \ + dim3 block(64, _WvPrGrp); \ + wvSplitKrc_ \ + <<>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \ + biasf4, glbl, c, CuCount); \ + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitKrc", [&] { + using fptype = typename scalar::type; + fptype* af4 = reinterpret_cast(in_a.data_ptr()); + const fptype* bf4 = reinterpret_cast(in_b.data_ptr()); + const fptype* biasf4 = + (in_bias.has_value() && in_bias->numel() > 0) + ? reinterpret_cast(in_bias->data_ptr()) + : nullptr; + fptype* c = reinterpret_cast(out_c.data_ptr()); + auto glbl = axl_glbl.data_ptr(); + switch (N_p2) { + case 16: + WVSPLITKrc(4, 16, 1, 16, 1) break; + case 32: + WVSPLITKrc(4, 16, 1, 32, 2) break; + case 64: + WVSPLITKrc(4, 16, 1, 64, 2) break; + case 128: + WVSPLITKrc(4, 16, 1, 128, 4) break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } + }); + return out_c; +} + #if defined(__HIP__MI3XX__) // TODO: Add NAVI support template @@ -1381,7 +1917,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) __shared__ fp8_t s[max_lds_len]; for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; - k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { + k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { *((bigType*)(&s[k])) = *((bigType*)(&A[k])); } __syncthreads(); @@ -1570,7 +2106,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) __shared__ fp8_t s[max_lds_len]; for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; - k < min(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { + k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { *((bigType*)(&s[k])) = *((bigType*)(&A[k])); } __syncthreads(); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 518486b1c..b0b44964c 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -26,6 +26,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { "Tensor"); rocm_ops.impl("wvSplitK", torch::kCUDA, &wvSplitK); + // Custom gemm op for skinny matrix-matrix multiplication + rocm_ops.def( + "wvSplitKrc(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> " + "Tensor"); + rocm_ops.impl("wvSplitKrc", torch::kCUDA, &wvSplitKrc); + // wvSplitK for fp8 rocm_ops.def( "wvSplitKQ(Tensor in_a, Tensor in_b, Tensor? in_bias, Tensor! out_c, " diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 15ff6d536..1f6464e21 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -8,9 +8,11 @@ import torch import vllm._custom_ops as ops from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant from vllm.platforms import current_platform +from vllm.platforms.rocm import on_gfx950 from vllm.utils.platform_utils import get_cu_count DTYPES = [torch.bfloat16, torch.float16] +BIAS_MODES = [0, 1, 2] # Specific (N, K, M) combinations for targeted testing NKM_FACTORS_LLMM1 = [ # Small, medium, large cases @@ -43,6 +45,31 @@ NKM_FACTORS_WVSPLITK = [ (4, 256, 8), ] +NKM_FACTORS_WVSPLITKRC = [ + (16, 2880, 128), + (16, 2880, 640), + (17, 2880, 128), + (17, 2880, 640), + (25, 2880, 128), + (25, 2880, 640), + (31, 2880, 128), + (31, 2880, 640), + (32, 2880, 128), + (32, 2880, 640), + (40, 2880, 128), + (40, 2880, 640), + (60, 2880, 128), + (60, 2880, 640), + (64, 2880, 128), + (64, 2880, 640), + (81, 2880, 128), + (81, 2880, 640), + (98, 2880, 128), + (98, 2880, 640), + (128, 2880, 128), + (128, 2880, 640), +] + NKM_FACTORS_WVSPLITK_FP8 = [ # FP8-specific cases with K % 16 == 0 (1, 16, 16), @@ -60,6 +87,32 @@ NKM_FACTORS_WVSPLITK_FP8 = [ SEEDS = [0] +@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("bias_mode", BIAS_MODES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") +@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950") +def test_rocm_wvsplitkrc_kernel(n, k, m, dtype, seed, bias_mode): + torch.manual_seed(seed) + cu_count = get_cu_count() + + xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas + A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier + B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier + + BIAS = None + if bias_mode == 1: + BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5 + elif bias_mode == 2: + BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5 + + ref_out = torch.nn.functional.linear(A, B, BIAS) + out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS) + + assert torch.allclose(out, ref_out, rtol=0.01) + + @pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16]) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7ff23e968..267b242d5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2072,6 +2072,12 @@ def wvSplitK( return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count) +def wvSplitKrc( + a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None +) -> torch.Tensor: + return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count) + + def wvSplitKQ( a: torch.Tensor, b: torch.Tensor, diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 4b7ba2eed..44a52b252 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -129,12 +129,32 @@ def use_aiter_triton_gemm(n, m, k, dtype): def rocm_unquantized_gemm_impl( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None ) -> torch.Tensor: - from vllm.platforms.rocm import on_gfx9 + from vllm.platforms.rocm import on_gfx9, on_gfx950 n = x.numel() / x.size(-1) m = weight.shape[0] k = weight.shape[1] + import math + + use_skinny_reduce_counting = ( + envs.VLLM_ROCM_USE_SKINNY_GEMM + and on_gfx950() + and x.dtype in [torch.float16, torch.bfloat16] + and ( + n >= 16 + and n <= 128 + and k > 512 + and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count() + ) + # k == 2880 and (m == 640 or m == 128)) + ) + if use_skinny_reduce_counting: + cu_count = get_cu_count() + x_view = x.reshape(-1, x.size(-1)) + out = ops.wvSplitKrc(weight, x_view, cu_count, bias) + return out.reshape(*x.shape[:-1], weight.shape[0]) + if use_aiter_triton_gemm(n, m, k, x.dtype): from aiter.ops.triton.gemm_a16w16 import gemm_a16w16