Adds padding and perf improvements to wvSplitK_fp8 (#33527)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
@@ -1899,8 +1899,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx,
|
||||
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp, const int M,
|
||||
const int Bx, const int By, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B, const int _WvPrGrp,
|
||||
@@ -1924,9 +1925,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
__shared__ fp8_t s[max_lds_len];
|
||||
|
||||
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
|
||||
k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
#if defined(__gfx950__)
|
||||
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
|
||||
#else
|
||||
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
|
||||
#endif
|
||||
}
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.y >= _WvPrGrp) return;
|
||||
@@ -1934,37 +1940,24 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
|
||||
|
||||
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
|
||||
floatx16 sum[N][YTILE];
|
||||
float sA = *s_A;
|
||||
float sB = *s_B;
|
||||
|
||||
while (m < M) {
|
||||
for (int i = 0; i < YTILE; i++)
|
||||
for (int n = 0; n < N; n++) sum[n][i] = {0.f};
|
||||
|
||||
bigType bigA[N][UNRL];
|
||||
bigType bigB[YTILE][UNRL];
|
||||
|
||||
floatx16 sum[N][YTILE] = {};
|
||||
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
#pragma unroll
|
||||
for (uint32_t n = 0; n < N; ++n) bigA[n][k2].h8 = {0.f};
|
||||
#pragma unroll
|
||||
for (uint32_t y = 0; y < YTILE; ++y) bigB[y][k2].h8 = {0.f};
|
||||
}
|
||||
bigType bigA[N][UNRL] = {};
|
||||
bigType bigB[YTILE][UNRL];
|
||||
|
||||
// Fetch the weight matrix from memory!
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
if (k_ >= K) break;
|
||||
|
||||
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
|
||||
const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
||||
#pragma unroll
|
||||
for (uint32_t y = 0; y < YTILE; ++y) {
|
||||
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
|
||||
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1975,16 +1968,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
if (k_ >= K) break;
|
||||
for (int n = 0; n < N; n++) {
|
||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
|
||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
|
||||
}
|
||||
}
|
||||
|
||||
// Do the matrix multiplication in interleaved manner
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||
if (k >= K) break;
|
||||
|
||||
for (uint32_t n = 0; n < N; n++) {
|
||||
for (int i = 0; i < A_CHUNK; i += 8) {
|
||||
for (int y = 0; y < YTILE; ++y) {
|
||||
@@ -2002,48 +1992,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
float accm0 = sum[n][y][0];
|
||||
float accm16 = sum[n][y][8];
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
|
||||
1); // row_shl1
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
|
||||
1); // row_shl2
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
|
||||
1); // row_shl3
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
|
||||
1); // row_shl8
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
|
||||
1); // row_shl9
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
|
||||
1); // row_shl10
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
|
||||
1); // row_shl11
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
|
||||
accm0 += __shfl(accm0, 36);
|
||||
accm16 += __shfl(accm16, 52);
|
||||
sum[n][y][0] = accm0 + __shfl(accm16, 16);
|
||||
@@ -2051,19 +2020,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
scalar_t biases[N][YTILE] = {};
|
||||
if (BIAS)
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
|
||||
}
|
||||
}
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
if (y + m >= M) break; // To avoid mem access fault.
|
||||
sum[n][y][0] *= sA * sB;
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
sum[n][y][0] += __half2float(biases[n][y]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] +=
|
||||
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
sum[n][y][0] += __bfloat162float(biases[n][y]);
|
||||
}
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]); // * sA * sB);
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2074,9 +2047,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
|
||||
const int Bx, const int By, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A,
|
||||
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
|
||||
const int M, const int Bx, const int By,
|
||||
const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS,
|
||||
scalar_t* C, const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B,
|
||||
@@ -2089,8 +2062,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx,
|
||||
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
wvSplitKQ_hf_(const int K, const int Kap, const int Kbp, const int M,
|
||||
const int Bx, const int By, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const float* __restrict__ s_A, const float* __restrict__ s_B,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
@@ -2113,9 +2087,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
__shared__ fp8_t s[max_lds_len];
|
||||
|
||||
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
|
||||
k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
#if defined(__gfx950__)
|
||||
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
|
||||
#else
|
||||
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
|
||||
#endif
|
||||
}
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.y >= _WvPrGrp) return;
|
||||
@@ -2123,29 +2102,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
|
||||
|
||||
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
|
||||
floatx16 sum[N][YTILE];
|
||||
float sA = *s_A;
|
||||
float sB = *s_B;
|
||||
|
||||
while (m < M) {
|
||||
for (int i = 0; i < YTILE; i++)
|
||||
for (int n = 0; n < N; n++) sum[n][i] = {0};
|
||||
|
||||
bigType bigA[N][UNRL];
|
||||
bigType bigB[YTILE][UNRL];
|
||||
|
||||
floatx16 sum[N][YTILE] = {};
|
||||
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
||||
bigType bigA[N][UNRL] = {};
|
||||
bigType bigB[YTILE][UNRL];
|
||||
|
||||
// Fetch the weight matrix from memory!
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
if (k_ >= K) break;
|
||||
|
||||
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
|
||||
const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
||||
for (int y = 0; y < YTILE; ++y) {
|
||||
if (y + m >= M) break; // To avoid mem access fault.
|
||||
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
|
||||
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2156,20 +2129,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
if (k_ >= K) break;
|
||||
for (int n = 0; n < N; n++) {
|
||||
if (k_ + K * n < max_lds_len)
|
||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
|
||||
if (k_ + Kap * n < max_lds_len)
|
||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
|
||||
else
|
||||
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
|
||||
bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
|
||||
}
|
||||
}
|
||||
|
||||
// Do the matrix multiplication in interleaved manner
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
if (k_ >= K) break;
|
||||
|
||||
for (uint32_t n = 0; n < N; n++) {
|
||||
for (int i = 0; i < A_CHUNK; i += 8) {
|
||||
for (int y = 0; y < YTILE; ++y) {
|
||||
@@ -2187,48 +2156,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
float accm0 = sum[n][y][0];
|
||||
float accm16 = sum[n][y][8];
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
|
||||
1); // row_shl1
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
|
||||
1); // row_shl2
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
|
||||
1); // row_shl3
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
|
||||
1); // row_shl8
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
|
||||
1); // row_shl9
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
|
||||
1); // row_shl10
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
|
||||
1); // row_shl11
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
|
||||
accm0 += __shfl(accm0, 36);
|
||||
accm16 += __shfl(accm16, 52);
|
||||
sum[n][y][0] = accm0 + __shfl(accm16, 16);
|
||||
@@ -2236,17 +2184,21 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
scalar_t biases[N][YTILE] = {};
|
||||
if (BIAS)
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
|
||||
}
|
||||
}
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
if (y + m >= M) break; // To avoid mem access fault.
|
||||
sum[n][y][0] *= sA * sB;
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
sum[n][y][0] += __half2float(biases[n][y]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] +=
|
||||
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
sum[n][y][0] += __bfloat162float(biases[n][y]);
|
||||
}
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
|
||||
}
|
||||
@@ -2259,9 +2211,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
|
||||
const int Bx, const int By, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A,
|
||||
__global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
|
||||
const int M, const int Bx, const int By,
|
||||
const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B, const int _WvPrGrp,
|
||||
@@ -2270,17 +2222,18 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
|
||||
}
|
||||
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
|
||||
|
||||
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
|
||||
const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
|
||||
const at::Tensor& scale_a, const at::Tensor& scale_b,
|
||||
const int64_t CuCount) {
|
||||
static c10::ScalarType kFp8Type = is_fp8_ocp()
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
: c10::ScalarType::Float8_e4m3fnuz;
|
||||
auto M_in = in_a.size(0);
|
||||
auto K_in = in_a.size(1);
|
||||
auto N_in = in_b.size(0);
|
||||
auto Kp_in = in_a.stride(0);
|
||||
auto M_in = in_b.size(0);
|
||||
auto K_in = in_b.size(1);
|
||||
auto N_in = in_a.size(0);
|
||||
auto Kap_in = in_a.stride(0);
|
||||
auto Kbp_in = in_b.stride(0);
|
||||
auto Bx_in =
|
||||
(in_bias.has_value() && in_bias->numel() > 0)
|
||||
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
|
||||
@@ -2300,23 +2253,22 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const int max_lds_len = get_lds_size();
|
||||
|
||||
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
|
||||
_N) \
|
||||
{ \
|
||||
dim3 block(64, _WvPrGrp); \
|
||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
|
||||
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
|
||||
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
|
||||
__wvPrGrp, CuCount); \
|
||||
} else { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
|
||||
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
|
||||
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
|
||||
__wvPrGrp, CuCount); \
|
||||
} \
|
||||
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
|
||||
{ \
|
||||
dim3 block(64, _WvPrGrp); \
|
||||
if ((Kap_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
||||
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEs, 16)); \
|
||||
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
||||
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
|
||||
s_a, s_b, __wvPrGrp, CuCount); \
|
||||
} else { \
|
||||
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEm, 16)); \
|
||||
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
||||
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
|
||||
s_a, s_b, __wvPrGrp, CuCount); \
|
||||
} \
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] {
|
||||
@@ -2332,16 +2284,16 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
: nullptr;
|
||||
switch (N_in) {
|
||||
case 1:
|
||||
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1)
|
||||
WVSPLITKQ(12, 2, 2, 2, 2, 1)
|
||||
break;
|
||||
case 2:
|
||||
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 2)
|
||||
WVSPLITKQ(12, 2, 2, 2, 2, 2)
|
||||
break;
|
||||
case 3:
|
||||
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 3)
|
||||
WVSPLITKQ(8, 2, 2, 1, 1, 3)
|
||||
break;
|
||||
case 4:
|
||||
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 4)
|
||||
WVSPLITKQ(4, 2, 2, 1, 1, 4)
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
|
||||
Reference in New Issue
Block a user