Adds padding and perf improvements to wvSplitK_fp8 (#33527)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
Hashem Hashemi
2026-02-05 14:16:02 -08:00
committed by GitHub
parent 42d5d705f9
commit d5c4800112
3 changed files with 169 additions and 229 deletions

View File

@@ -1899,8 +1899,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx,
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp, const int M,
const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp,
@@ -1924,9 +1925,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
__shared__ fp8_t s[max_lds_len];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
}
asm volatile("s_waitcnt vmcnt(0)");
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
@@ -1934,37 +1940,24 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
floatx16 sum[N][YTILE];
float sA = *s_A;
float sB = *s_B;
while (m < M) {
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = {0.f};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
floatx16 sum[N][YTILE] = {};
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
#pragma unroll
for (uint32_t n = 0; n < N; ++n) bigA[n][k2].h8 = {0.f};
#pragma unroll
for (uint32_t y = 0; y < YTILE; ++y) bigB[y][k2].h8 = {0.f};
}
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];
#pragma unroll
for (uint32_t y = 0; y < YTILE; ++y) {
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
}
}
@@ -1975,16 +1968,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (int n = 0; n < N; n++) {
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
if (k >= K) break;
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
@@ -2002,48 +1992,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
1); // row_shl8
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
1); // row_shl9
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
1); // row_shl10
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
1); // row_shl11
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
@@ -2051,19 +2020,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
if (threadIdx.x == 0) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault.
sum[n][y][0] *= sA * sB;
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
sum[n][y][0] += __half2float(biases[n][y]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][y][0] +=
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
sum[n][y][0] += __bfloat162float(biases[n][y]);
}
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]); // * sA * sB);
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
}
}
}
@@ -2074,9 +2047,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
const int M, const int Bx, const int By,
const fp8_t* B, const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS,
scalar_t* C, const float* __restrict__ s_A,
const float* __restrict__ s_B,
@@ -2089,8 +2062,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx,
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
wvSplitKQ_hf_(const int K, const int Kap, const int Kbp, const int M,
const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A, const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) {
@@ -2113,9 +2087,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
__shared__ fp8_t s[max_lds_len];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
}
asm volatile("s_waitcnt vmcnt(0)");
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
@@ -2123,29 +2102,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
floatx16 sum[N][YTILE];
float sA = *s_A;
float sB = *s_B;
while (m < M) {
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = {0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
floatx16 sum[N][YTILE] = {};
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];
for (int y = 0; y < YTILE; ++y) {
if (y + m >= M) break; // To avoid mem access fault.
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
}
}
@@ -2156,20 +2129,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (int n = 0; n < N; n++) {
if (k_ + K * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
if (k_ + Kap * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
@@ -2187,48 +2156,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
1); // row_shl8
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
1); // row_shl9
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
1); // row_shl10
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
1); // row_shl11
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
@@ -2236,17 +2184,21 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
if (threadIdx.x == 0) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault.
sum[n][y][0] *= sA * sB;
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
sum[n][y][0] += __half2float(biases[n][y]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][y][0] +=
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
sum[n][y][0] += __bfloat162float(biases[n][y]);
}
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
}
@@ -2259,9 +2211,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
__global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
const int M, const int Bx, const int By,
const fp8_t* B, const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp,
@@ -2270,17 +2222,18 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
}
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
const at::Tensor& scale_a, const at::Tensor& scale_b,
const int64_t CuCount) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
auto M_in = in_a.size(0);
auto K_in = in_a.size(1);
auto N_in = in_b.size(0);
auto Kp_in = in_a.stride(0);
auto M_in = in_b.size(0);
auto K_in = in_b.size(1);
auto N_in = in_a.size(0);
auto Kap_in = in_a.stride(0);
auto Kbp_in = in_b.stride(0);
auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
@@ -2300,23 +2253,22 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size();
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
__wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
__wvPrGrp, CuCount); \
} \
/*
 * WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N)
 *
 * Expands to a braced statement that launches one of the two fp8 split-K
 * kernels with a 64 x _WvPrGrp thread block on `stream`.
 *
 *   _WvPrGrp  - block y-dimension; also the kernels' WvPrGrp template argument.
 *   _YTILEs   - YTILE template argument for wvSplitKQ_hf_sml_.
 *   _YTILEm   - YTILE template argument for wvSplitKQ_hf_.
 *   _UNRLs    - UNRL template argument for wvSplitKQ_hf_sml_.
 *   _UNRLm    - UNRL template argument for wvSplitKQ_hf_.
 *   _N        - N template argument for both kernels.
 *
 * Kernel selection:
 *   - wvSplitKQ_hf_sml_ is used when Kap_in * N_in <= max_lds_len and M_in is
 *     a multiple of _YTILEs.
 *   - wvSplitKQ_hf_ is used otherwise.
 *
 * Both launches hard-code THRDS = 64 and A_CHUNK = 16 as template arguments.
 * b_ptr is passed in the kernels' `B` slot and a_ptr in the `A` slot.
 *
 * __wvPrGrp is forwarded as the kernels' runtime `_WvPrGrp` argument and is
 * capped at the compile-time _WvPrGrp by the min(); threads whose threadIdx.y
 * is >= that value return early inside the kernels.
 * NOTE(review): mindiv() is defined elsewhere; presumably it picks a
 * waves-per-group count from M_in and CuCount * YTILE with an upper limit of
 * 16 - confirm against its definition.
 *
 * The macro is not self-contained. It reads these names from the enclosing
 * scope: fptype, fp8_t, grid, stream, max_lds_len, CuCount, K_in, Kap_in,
 * Kbp_in, M_in, N_in, Bx_in, By_in, a_ptr, b_ptr, bias_ptr, c_ptr, s_a, s_b.
 *
 * NOTE(review): no launch-error check appears inside the macro; verify that
 * the caller checks for launch failures.
 */
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((Kap_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEs, 16)); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEm, 16)); \
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
} \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] {
@@ -2332,16 +2284,16 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
: nullptr;
switch (N_in) {
case 1:
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1)
WVSPLITKQ(12, 2, 2, 2, 2, 1)
break;
case 2:
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 2)
WVSPLITKQ(12, 2, 2, 2, 2, 2)
break;
case 3:
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 3)
WVSPLITKQ(8, 2, 2, 1, 1, 3)
break;
case 4:
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 4)
WVSPLITKQ(4, 2, 2, 1, 1, 4)
break;
default:
throw std::runtime_error(