Add padding support to wvSplitK solution for skinny GEMMs (#33762)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Hashem Hashemi authored 2026-02-28 01:02:05 -08:00 · committed by GitHub
parent 1e69c04887 · commit 7600642eae
3 changed files with 289 additions and 444 deletions
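The diff below threads two new arguments through every wvSplitK kernel variant: Kbp and Kap, the padded K strides (leading dimensions) of the weight matrix B and the activation matrix A. A minimal host-side sketch of how such padded inputs arise, assuming a row-major view sliced out of a wider buffer (tensor names and sizes here are illustrative, not from the commit):

    import torch

    # A "padded" activation is a view whose row stride exceeds its logical width.
    n, k, pad = 4, 4096, 16
    buf = torch.empty(n, k + pad, dtype=torch.bfloat16, device="cuda")
    A = buf[:, :k]                   # logical shape (n, k)
    assert A.stride(0) == k + pad    # what the host code below passes as Kap_in
    assert not A.is_contiguous()     # why the is_contiguous() gate is dropped at the end

Inside the kernels, every index that previously used K as a row stride (B[... * K], s[k_ + K * n], A[k_ + K * n]) now uses Kbp or Kap, while K itself still bounds the accumulation loops.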


@@ -304,8 +304,9 @@ __device__ inline unsigned int min__(uint32_t a, uint32_t b) {
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap, const int M,
const int Bx, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2;
@@ -314,7 +315,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else
constexpr bool use_mfma = false;
#endif
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
using half4 =
@@ -346,13 +346,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min__(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
}
__syncthreads();
@@ -360,9 +360,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
@@ -386,44 +383,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
float sum[N][YTILE] = {};
scalar8 sum4[N][YTILE] = {};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
// for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
for (int y = 0; y < YTILE; y++)
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K])));
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
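The weight fetch in the hunk above no longer breaks out of the loop when k_ runs past K; instead the chunk start is clamped to the last fully in-bounds A_CHUNK and the row index is clamped to M - 1, so the 8-element vector loads never leave the allocation. The redundant loads are harmless: the compute loop still skips k_ >= K, and out-of-range rows are never committed. A small sketch of the index math (ours, assuming the same A_CHUNK of 8):

    A_CHUNK = 8

    def b_chunk_offset(k_: int, y: int, m: int, K: int, M: int, Kbp: int) -> int:
        # Mirrors &B[min__(k_, K - A_CHUNK)] followed by B_[min__(y + m, M - 1) * Kbp]:
        # lanes past the K edge and rows past M - 1 re-read the last valid chunk
        # instead of reading out of bounds.
        return min(k_, K - A_CHUNK) + min(y + m, M - 1) * Kbp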
@@ -432,33 +405,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
@@ -466,46 +426,44 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
}
}
__builtin_amdgcn_sched_barrier(0);
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf,
1); // row_shr8
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf,
1); // row_shr4
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf,
1); // row_shr2
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
1); // row_shr1
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
1); // ROW_BCAST15
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
1); // ROW_BCAST31
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if constexpr (std::is_same_v<scalar_t, half>) {
sum[n][y] += __half2float(biases[n][y]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
sum[n][y] += __bfloat162float(biases[n][y]);
}
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y]);
}
}
}
@@ -514,45 +472,43 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int n = 0; n < N; n++) {
#pragma unroll
for (int y = 0; y < YTILE; y++) {
// float accm1 = 0;
// for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i);
/*float accm1 = 0;
for (int i=0; i<64; i++)
accm1 += __shfl(sum4[n][y][i%4], i);
sum4[n][y][0] = accm1;*/
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf,
1); // row_shl4
accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf,
1); // row_shl8
accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf,
1); // row_shr15
accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf,
1); // ROW_BCAST15
accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf,
1); // ROW_BCAST31
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (BIAS)
sum4[n][i][0] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
for (int y = 0; y < YTILE; y++) {
sum4[n][y][0] += __bfloat162float(biases[n][y]);
C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]);
}
}
}
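Both reduction paths replace the inline v_add_f32/v_mov_b32 DPP assembly with __builtin_amdgcn_mov_dpp calls carrying raw dpp_ctrl immediates. As a reference (ours, not part of the commit, assuming the standard GFX9 DPP control encoding of row_shl n = 0x100 + n and row_shr n = 0x110 + n), the immediates used above decode as:

    # dpp_ctrl immediate -> assembly spelling from the replaced asm blocks
    DPP_CTRL = {
        0x101: "row_shl:1", 0x102: "row_shl:2", 0x103: "row_shl:3",
        0x104: "row_shl:4", 0x108: "row_shl:8",
        0x111: "row_shr:1", 0x112: "row_shr:2", 0x114: "row_shr:4",
        0x118: "row_shr:8", 0x11F: "row_shr:15",
        0x142: "row_bcast:15", 0x143: "row_bcast:31",
    }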
@@ -563,8 +519,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
const int By, const scalar_t* B,
__global__ void wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap,
const int M, const int Bx, const int By,
const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
@@ -577,8 +534,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx,
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_(const int K, const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
wvSplitK_hf_(const int K, const int Kbp, const int Kap, const int M,
const int Bx, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2;
@@ -601,13 +559,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
scalar8 h8;
};
//----------------------------------------------------
// Reserving 64 KB of LDS to have 1 WG / CU
// Goal is to bring the activation matrix A to the LDS
// and use it across the lifetime of the work group
// TODO: When activation matrix is larger than 64 KB
// then this is not going to work!
//----------------------------------------------------
__shared__ scalar_t s[max_lds_len];
//----------------------------------------------------
@@ -618,12 +569,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
commitColumn[i] = 1;
}
//----------------------------------------------------
// Indexing function into the column of weight matrix B
// Algorithm does 64 lane k-splitting / wave and uses
// WG ID and Thread ID to find the index.
//----------------------------------------------------
// int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp);
uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE;
// Check whether there will be fragmentation!
@@ -636,91 +581,34 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
m = startColumn;
}
//----------------------------------------------------
// Fetch the activation matrix to LDS
// Loop iteration:
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements
// - Each WG will fetch 512 * 16 => 8K elements
// - Then the WG will move to another 8 K elements
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min__(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
}
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
// There are 16 waves per WG, and hence, each WG is
// working on 16 columns of weight matrix. Moreover,
// we tile in column direction by YTILE, so when YTILE=1
// the above math is right, however, when YTILE=2 then
// each wave will be working on 2 columns and WG will
// be working on 32 columns.
//
// Top level loop that makes WGs persistent!
// - WGs iterates across columns of weight matrix
// - Each wave within WG works on a given column(s)
// - After completing first set of columns, WGs start
// working on the next set of available columns
//----------------------------------------------------
while (m < M) {
//----------------------------------------------------
// 'sum' accumulates the matrix A x B computation
// split across 64 lanes.
//
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
float sum[N][YTILE] = {};
scalar8 sum4[N][YTILE] = {};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
// - Each wave will fetch 64*8=> 512 elements (1024B)
// - YTILE represents the number of column being serviced
// by wave
// - Loop for fetching weight matrix (B) are unrolled
//
// Fetch activation matrix A from LDS
// - Loop for fetching activation matrix (A) are unrolled
//
// Finally, do the matrix multiplication in an unrolled
// fashion. This provides lot of food for compiler
// scheduling.
//
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
for (int b = 0; b < YTILE; b++)
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
for (int y = 0; y < YTILE; y++)
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
@@ -729,36 +617,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
if (k_ + K * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
if (k_ + Kap * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
@@ -773,40 +648,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf,
1); // row_shr8
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf,
1); // row_shr4
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf,
1); // row_shr2
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
1); // row_shr1
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
1); // ROW_BCAST15
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
1); // ROW_BCAST31
}
}
if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) {
for (int y = 0; y < YTILE; y++) {
if (commitColumn[y]) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
sum[n][y] += __half2float(biases[n][y]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
sum[n][y] += __bfloat162float(biases[n][y]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y]);
}
}
}
@@ -819,44 +692,39 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// float accm1 = 0;
// for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i);
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf,
1); // row_shl4
accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf,
1); // row_shl8
accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf,
1); // row_shr15
accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf,
1); // ROW_BCAST15
accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf,
1); // ROW_BCAST31
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) {
if (BIAS)
sum4[n][i][0] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
for (int y = 0; y < YTILE; y++) {
if (commitColumn[y]) {
sum4[n][y][0] += __bfloat162float(biases[n][y]);
C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]);
}
}
}
@@ -880,9 +748,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
__global__ void wvSplitK_hf_(const int K, const int Kbp, const int Kap,
const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
UNREACHABLE_CODE
@@ -894,8 +762,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const int Bx,
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By,
const scalar_t* B, const scalar_t* __restrict__ A,
wvSplitK_hf_big_(const int K, const int Kbp, const int Kap, const int M,
const int Bx, const int By, const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
constexpr int max_lds_len = LDS_SIZE / 2;
@@ -966,13 +835,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
#define PCML
#ifndef PCML
for (uint32_t k = 0; k < min__(K * N, max_lds_len);
k += THRDS * WvPrGrp * A_CHUNK) {
uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (k_in >= min__(K * N, max_lds_len)) break;
*((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in]));
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
}
__syncthreads();
#endif
@@ -987,10 +856,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
? kFit
: (kFit - kFit % TUC); // round up to multiple of TUC
// if (kFit == 0) kFit = TUC;
kFit = min__(kFit, K);
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
kFit = min__(kFit, Kap);
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
@@ -1021,15 +887,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// YTILE represents how many column of weight matrix
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
float sum[N][YTILE] = {};
scalar8 sum4[N][YTILE] = {};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
//----------------------------------------------------
// Fetch weight matrix B in interleaved K-split!
// - Each thread (lane) is fetching 8 elements (A_Chunk)
@@ -1048,18 +908,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// TODO: Logic below will only work when K is multiple of 8
//----------------------------------------------------
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
#ifdef PCML
if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS
if (k1 != 0) kBase += kFit;
__syncthreads();
for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) {
uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK);
if (kBase + kOff >= K) break;
if (kBase + kOff >= Kap) break;
if (kOff >= kFit) break;
for (uint32_t n = 0; n < N; n++) {
uint32_t k_in = kBase + n * K + kOff;
uint32_t k_in = kBase + n * Kap + kOff;
uint32_t k_ot = n * kFit + kOff;
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k_in]), (int*)(&s[k_ot]),
16, 0, 0);
#else
*((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in]));
#endif
}
}
__syncthreads();
@@ -1072,11 +940,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
for (int b = 0; b < YTILE; b++)
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)];
for (int y = 0; y < YTILE; y++)
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
@@ -1085,17 +951,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
// Fetch A activation matrix in interleaved fashion from LDS or memory
for (int n = 0; n < N; n++) {
#ifdef PCML
bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n])));
#else
if (k_ + K * n < 32 * 1024)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
if (k_ + Kap * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
#endif
}
}
@@ -1103,22 +966,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
@@ -1141,40 +995,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf,
1); // row_shr8
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf,
1); // row_shr4
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf,
1); // row_shr2
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
1); // row_shr1
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
1); // ROW_BCAST15
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
1); // ROW_BCAST31
}
}
if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) {
for (int y = 0; y < YTILE; y++) {
if (commitColumn[y]) {
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]);
sum[n][y] += __half2float(biases[n][y]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][i] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
sum[n][y] += __bfloat162float(biases[n][y]);
}
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y]);
}
}
}
@@ -1185,42 +1037,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
for (int y = 0; y < YTILE; y++) {
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf,
1); // row_shl4
accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf,
1); // row_shl8
accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf,
1); // row_shr15
accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf,
1); // ROW_BCAST15
accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf,
1); // ROW_BCAST31
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i]) {
if (BIAS)
sum4[n][i][0] +=
__bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
for (int y = 0; y < YTILE; y++) {
if (commitColumn[y]) {
sum4[n][y][0] += __bfloat162float(biases[n][y]);
C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]);
}
}
}
@@ -1244,8 +1092,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
int UNRL, int N>
__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
const int By, const scalar_t* B,
__global__ void wvSplitK_hf_big_(const int K, const int Kbp, const int Kap,
const int M, const int Bx, const int By,
const scalar_t* B,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
@@ -1272,6 +1121,8 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
auto M_in = in_a.size(0);
auto K_in = in_a.size(1);
auto N_in = in_b.size(0);
auto Kap_in = in_a.stride(0);
auto Kbp_in = in_b.stride(0);
auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
@@ -1296,27 +1147,30 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size() / 2;
#define WVSPLITK(_YTILE, _UNRL, _N) \
{ \
dim3 block(64, 16); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else if (K_in * N_in <= max_lds_len * 1.2) \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
else \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
biasf4, c, __wvPrGrp, CuCount); \
#define WVSPLITK(_YTILE, _UNRL, _N) \
{ \
dim3 block(64, 16); \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
if ((Kbp_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
else if (Kbp_in * N_in <= max_lds_len * 1.2) \
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
else \
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
CuCount); \
}
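A Python rendering (ours, for readability only) of the dispatch the updated WVSPLITK macro performs; the LDS-fit checks now use the padded stride Kbp_in rather than K_in:

    def pick_kernel(Kbp: int, N: int, M: int, YTILE: int, max_lds_len: int) -> str:
        if Kbp * N <= max_lds_len and M % YTILE == 0:
            return "wvSplitK_hf_sml_"   # padded activation fits in LDS
        if Kbp * N <= max_lds_len * 1.2:
            return "wvSplitK_hf_"       # nearly fits; overflow reads go to memory
        return "wvSplitK_hf_big_"       # streamed through LDS in kFit-sized chunks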
#define WVSPLIT_TILE(_sYT, __N) \
{ \
bool fit_lds = (K_in * N_in <= max_lds_len); \
bool fit_lds = (Kbp_in * N_in <= max_lds_len); \
if (_sYT <= 1) \
WVSPLITK(1, 4, __N) \
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \


@@ -30,15 +30,22 @@ NKM_FACTORS_LLMM1 = [
NKM_FACTORS_WVSPLITK = [
# Different batch sizes with key dimensions
(1, 16, 16),
(1, 32, 16),
(1, 64, 64),
(2, 256, 256),
(3, 1024, 1024),
(4, 4096, 4096),
(4, 4096, 4096 + 1),
(4, 4096 + 16, 4096),
(4, 4096 + 16, 4096 + 1),
# Extended K values
(1, 9216, 512),
(2, 10240, 1024),
(4, 16384, 8192),
(4, 16384 * 2, 8192),
(4, 16384 * 2, 8192 + 1),
(4, 16384 * 2 + 16, 8192),
(4, 16384 * 2 + 16, 8192 + 1),
# Minimum M constraint validation (m >= 8)
(1, 64, 8),
(2, 128, 8),
@@ -180,59 +187,44 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("bias_mode", BIAS_MODES)
@pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("padded_b", [False, True])
def test_rocm_wvsplitk_kernel(
xnorm, n, k, m, dtype, seed, bias_mode, padded_a, padded_b
):
torch.manual_seed(seed)
cu_count = num_compute_units()
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
xavier = (
math.sqrt(2 / k) if xnorm else 1
) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier
ref_out = torch.nn.functional.linear(A, B)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
BIAS = None
if bias_mode == 1:
BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
elif bias_mode == 2:
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = num_compute_units()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
if padded_a:
A = pad_fp8(A)
if padded_b:
B = pad_fp8(B)
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = num_compute_units()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2)
if xnorm:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
else:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2)
@pytest.mark.parametrize("xnorm", [False, True])


@@ -191,7 +191,6 @@ def rocm_unquantized_gemm_impl(
and on_gfx9()
and x.dtype in [torch.float16, torch.bfloat16]
and k % 8 == 0
and x.is_contiguous()
)
if use_skinny is not True: