Adds padding and perf improvements to wvSplitK_fp8 (#33527)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
@@ -1899,8 +1899,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx,
|
||||
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp, const int M,
|
||||
const int Bx, const int By, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B, const int _WvPrGrp,
|
||||
@@ -1924,9 +1925,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
__shared__ fp8_t s[max_lds_len];
|
||||
|
||||
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
|
||||
k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
#if defined(__gfx950__)
|
||||
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
|
||||
#else
|
||||
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
|
||||
#endif
|
||||
}
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.y >= _WvPrGrp) return;
|
||||
@@ -1934,37 +1940,24 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
|
||||
|
||||
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
|
||||
floatx16 sum[N][YTILE];
|
||||
float sA = *s_A;
|
||||
float sB = *s_B;
|
||||
|
||||
while (m < M) {
|
||||
for (int i = 0; i < YTILE; i++)
|
||||
for (int n = 0; n < N; n++) sum[n][i] = {0.f};
|
||||
|
||||
bigType bigA[N][UNRL];
|
||||
bigType bigB[YTILE][UNRL];
|
||||
|
||||
floatx16 sum[N][YTILE] = {};
|
||||
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
#pragma unroll
|
||||
for (uint32_t n = 0; n < N; ++n) bigA[n][k2].h8 = {0.f};
|
||||
#pragma unroll
|
||||
for (uint32_t y = 0; y < YTILE; ++y) bigB[y][k2].h8 = {0.f};
|
||||
}
|
||||
bigType bigA[N][UNRL] = {};
|
||||
bigType bigB[YTILE][UNRL];
|
||||
|
||||
// Fetch the weight matrix from memory!
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
if (k_ >= K) break;
|
||||
|
||||
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
|
||||
const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
||||
#pragma unroll
|
||||
for (uint32_t y = 0; y < YTILE; ++y) {
|
||||
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
|
||||
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1975,16 +1968,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
if (k_ >= K) break;
|
||||
for (int n = 0; n < N; n++) {
|
||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
|
||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
|
||||
}
|
||||
}
|
||||
|
||||
// Do the matrix multiplication in interleaved manner
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||
if (k >= K) break;
|
||||
|
||||
for (uint32_t n = 0; n < N; n++) {
|
||||
for (int i = 0; i < A_CHUNK; i += 8) {
|
||||
for (int y = 0; y < YTILE; ++y) {
|
||||
@@ -2002,48 +1992,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
float accm0 = sum[n][y][0];
|
||||
float accm16 = sum[n][y][8];
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
|
||||
1); // row_shl1
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
|
||||
1); // row_shl2
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
|
||||
1); // row_shl3
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
|
||||
1); // row_shl8
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
|
||||
1); // row_shl9
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
|
||||
1); // row_shl10
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
|
||||
1); // row_shl11
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
|
||||
accm0 += __shfl(accm0, 36);
|
||||
accm16 += __shfl(accm16, 52);
|
||||
sum[n][y][0] = accm0 + __shfl(accm16, 16);
|
||||
@@ -2051,19 +2020,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
scalar_t biases[N][YTILE] = {};
|
||||
if (BIAS)
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
|
||||
}
|
||||
}
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
if (y + m >= M) break; // To avoid mem access fault.
|
||||
sum[n][y][0] *= sA * sB;
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
sum[n][y][0] += __half2float(biases[n][y]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] +=
|
||||
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
sum[n][y][0] += __bfloat162float(biases[n][y]);
|
||||
}
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]); // * sA * sB);
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2074,9 +2047,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
|
||||
const int Bx, const int By, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A,
|
||||
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
|
||||
const int M, const int Bx, const int By,
|
||||
const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS,
|
||||
scalar_t* C, const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B,
|
||||
@@ -2089,8 +2062,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx,
|
||||
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
wvSplitKQ_hf_(const int K, const int Kap, const int Kbp, const int M,
|
||||
const int Bx, const int By, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const float* __restrict__ s_A, const float* __restrict__ s_B,
|
||||
const int _WvPrGrp, const int CuCount) {
|
||||
@@ -2113,9 +2087,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
__shared__ fp8_t s[max_lds_len];
|
||||
|
||||
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
|
||||
k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
|
||||
#if defined(__gfx950__)
|
||||
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
|
||||
#else
|
||||
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
|
||||
#endif
|
||||
}
|
||||
asm volatile("s_waitcnt vmcnt(0)");
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.y >= _WvPrGrp) return;
|
||||
@@ -2123,29 +2102,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
|
||||
|
||||
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
|
||||
floatx16 sum[N][YTILE];
|
||||
float sA = *s_A;
|
||||
float sB = *s_B;
|
||||
|
||||
while (m < M) {
|
||||
for (int i = 0; i < YTILE; i++)
|
||||
for (int n = 0; n < N; n++) sum[n][i] = {0};
|
||||
|
||||
bigType bigA[N][UNRL];
|
||||
bigType bigB[YTILE][UNRL];
|
||||
|
||||
floatx16 sum[N][YTILE] = {};
|
||||
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
||||
bigType bigA[N][UNRL] = {};
|
||||
bigType bigB[YTILE][UNRL];
|
||||
|
||||
// Fetch the weight matrix from memory!
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
if (k_ >= K) break;
|
||||
|
||||
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
|
||||
const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];
|
||||
for (int y = 0; y < YTILE; ++y) {
|
||||
if (y + m >= M) break; // To avoid mem access fault.
|
||||
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
|
||||
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2156,20 +2129,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
if (k_ >= K) break;
|
||||
for (int n = 0; n < N; n++) {
|
||||
if (k_ + K * n < max_lds_len)
|
||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
|
||||
if (k_ + Kap * n < max_lds_len)
|
||||
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
|
||||
else
|
||||
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
|
||||
bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
|
||||
}
|
||||
}
|
||||
|
||||
// Do the matrix multiplication in interleaved manner
|
||||
#pragma unroll
|
||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
|
||||
uint32_t k_ = k + threadIdx.x * A_CHUNK;
|
||||
if (k_ >= K) break;
|
||||
|
||||
for (uint32_t n = 0; n < N; n++) {
|
||||
for (int i = 0; i < A_CHUNK; i += 8) {
|
||||
for (int y = 0; y < YTILE; ++y) {
|
||||
@@ -2187,48 +2156,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
float accm0 = sum[n][y][0];
|
||||
float accm16 = sum[n][y][8];
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
|
||||
: "=v"(accm0)
|
||||
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
|
||||
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
|
||||
: "=v"(accm16)
|
||||
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
|
||||
1); // row_shl1
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
|
||||
1); // row_shl2
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
|
||||
1); // row_shl3
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
|
||||
1); // row_shl8
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
|
||||
1); // row_shl9
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
|
||||
1); // row_shl10
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
|
||||
1); // row_shl11
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
|
||||
accm0 += __shfl(accm0, 36);
|
||||
accm16 += __shfl(accm16, 52);
|
||||
sum[n][y][0] = accm0 + __shfl(accm16, 16);
|
||||
@@ -2236,17 +2184,21 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
scalar_t biases[N][YTILE] = {};
|
||||
if (BIAS)
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
|
||||
}
|
||||
}
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
if (y + m >= M) break; // To avoid mem access fault.
|
||||
sum[n][y][0] *= sA * sB;
|
||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
sum[n][y][0] += __half2float(biases[n][y]);
|
||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||
if (BIAS)
|
||||
sum[n][y][0] +=
|
||||
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
|
||||
sum[n][y][0] += __bfloat162float(biases[n][y]);
|
||||
}
|
||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
|
||||
}
|
||||
@@ -2259,9 +2211,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
|
||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||
int A_CHUNK, int UNRL, int N>
|
||||
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
|
||||
const int Bx, const int By, const fp8_t* B,
|
||||
const fp8_t* __restrict__ A,
|
||||
__global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
|
||||
const int M, const int Bx, const int By,
|
||||
const fp8_t* B, const fp8_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ BIAS, scalar_t* C,
|
||||
const float* __restrict__ s_A,
|
||||
const float* __restrict__ s_B, const int _WvPrGrp,
|
||||
@@ -2270,17 +2222,18 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
|
||||
}
|
||||
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
|
||||
|
||||
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
|
||||
const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
|
||||
const at::Tensor& scale_a, const at::Tensor& scale_b,
|
||||
const int64_t CuCount) {
|
||||
static c10::ScalarType kFp8Type = is_fp8_ocp()
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
: c10::ScalarType::Float8_e4m3fnuz;
|
||||
auto M_in = in_a.size(0);
|
||||
auto K_in = in_a.size(1);
|
||||
auto N_in = in_b.size(0);
|
||||
auto Kp_in = in_a.stride(0);
|
||||
auto M_in = in_b.size(0);
|
||||
auto K_in = in_b.size(1);
|
||||
auto N_in = in_a.size(0);
|
||||
auto Kap_in = in_a.stride(0);
|
||||
auto Kbp_in = in_b.stride(0);
|
||||
auto Bx_in =
|
||||
(in_bias.has_value() && in_bias->numel() > 0)
|
||||
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
|
||||
@@ -2300,23 +2253,22 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const int max_lds_len = get_lds_size();
|
||||
|
||||
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
|
||||
_N) \
|
||||
{ \
|
||||
dim3 block(64, _WvPrGrp); \
|
||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
|
||||
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
|
||||
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
|
||||
__wvPrGrp, CuCount); \
|
||||
} else { \
|
||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
|
||||
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
|
||||
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
|
||||
__wvPrGrp, CuCount); \
|
||||
} \
|
||||
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
|
||||
{ \
|
||||
dim3 block(64, _WvPrGrp); \
|
||||
if ((Kap_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
||||
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEs, 16)); \
|
||||
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
||||
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
|
||||
s_a, s_b, __wvPrGrp, CuCount); \
|
||||
} else { \
|
||||
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEm, 16)); \
|
||||
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
|
||||
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
||||
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
|
||||
s_a, s_b, __wvPrGrp, CuCount); \
|
||||
} \
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] {
|
||||
@@ -2332,16 +2284,16 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||
: nullptr;
|
||||
switch (N_in) {
|
||||
case 1:
|
||||
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1)
|
||||
WVSPLITKQ(12, 2, 2, 2, 2, 1)
|
||||
break;
|
||||
case 2:
|
||||
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 2)
|
||||
WVSPLITKQ(12, 2, 2, 2, 2, 2)
|
||||
break;
|
||||
case 3:
|
||||
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 3)
|
||||
WVSPLITKQ(8, 2, 2, 1, 1, 3)
|
||||
break;
|
||||
case 4:
|
||||
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 4)
|
||||
WVSPLITKQ(4, 2, 2, 1, 1, 4)
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
|
||||
@@ -73,21 +73,40 @@ NKM_FACTORS_WVSPLITKRC = [
|
||||
NKM_FACTORS_WVSPLITK_FP8 = [
|
||||
# FP8-specific cases with K % 16 == 0
|
||||
(1, 16, 16),
|
||||
(1, 32, 16 + 16),
|
||||
(1, 64, 64),
|
||||
(1, 64, 64 + 16),
|
||||
(1, 64 + 16, 64),
|
||||
(1, 64 + 16, 64 + 16),
|
||||
(4, 64, 64),
|
||||
(4, 64, 64 + 16),
|
||||
(4, 64 + 16, 64),
|
||||
(4, 64 + 16, 64 + 16),
|
||||
(2, 512, 512),
|
||||
(3, 512, 512),
|
||||
(3, 512, 512 + 16),
|
||||
(4, 512, 512),
|
||||
(3, 2048, 2048),
|
||||
(3, 2048, 2048 + 16),
|
||||
(4, 2048 + 16, 2048),
|
||||
(4, 2048 + 16, 2048 + 16),
|
||||
(4, 4096, 4096),
|
||||
(4, 16400, 2048),
|
||||
(4, 16400, 2048 + 16),
|
||||
# Extended FP8 dimensions not covered by WVSPLITK
|
||||
(1, 14336, 1024),
|
||||
(2, 24576, 2048),
|
||||
(4, 32768, 28672),
|
||||
(4, 32768 * 2, 28672),
|
||||
(4, 32768 * 2, 28672 + 16),
|
||||
(4, 32768 * 2 + 16, 28672),
|
||||
(4, 32768 * 2 + 16, 28672 + 16),
|
||||
]
|
||||
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def pad_weights_fp8(weight):
|
||||
def pad_fp8(weight):
|
||||
num_pad = 256 // weight.element_size()
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -195,72 +214,41 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("xnorm", [False, True])
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("padded", [False, True])
|
||||
@pytest.mark.parametrize("padded_a", [False, True])
|
||||
@pytest.mark.parametrize("padded_b", [False, True])
|
||||
@pytest.mark.parametrize("biased", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
not (current_platform.is_rocm() and current_platform.supports_fp8()),
|
||||
reason="only test for rocm fp8",
|
||||
)
|
||||
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed, padded):
|
||||
def test_rocm_wvsplitk_fp8_kernel(
|
||||
xnorm, n, k, m, dtype, seed, padded_a, padded_b, biased
|
||||
):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
A = torch.rand(n, k, device="cuda") - 0.5
|
||||
B = torch.rand(m, k, device="cuda") - 0.5
|
||||
xavier = math.sqrt(2 / k) if xnorm else 1 # normalize to avoid large deltas
|
||||
A = (torch.rand(n, k, device="cuda") * 2 - 1) * xavier
|
||||
B = (torch.rand(m, k, device="cuda") * 2 - 1) * xavier
|
||||
|
||||
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
|
||||
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
|
||||
if padded:
|
||||
B = pad_weights_fp8(B)
|
||||
if padded_b:
|
||||
B = pad_fp8(B)
|
||||
if padded_a:
|
||||
A = pad_fp8(A)
|
||||
|
||||
ref_out = torch._scaled_mm(
|
||||
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
|
||||
)
|
||||
out = ops.wvSplitKQ(
|
||||
B,
|
||||
A,
|
||||
dtype,
|
||||
scale_a,
|
||||
scale_b,
|
||||
get_cu_count(),
|
||||
)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("padded", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
not (current_platform.is_rocm() and current_platform.supports_fp8()),
|
||||
reason="only test for rocm fp8",
|
||||
)
|
||||
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed, padded):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, device="cuda") - 0.5) * xavier
|
||||
B = (torch.rand(m, k, device="cuda") - 0.5) * xavier
|
||||
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
|
||||
|
||||
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
|
||||
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
|
||||
if padded:
|
||||
B = pad_weights_fp8(B)
|
||||
BIAS = None if (not biased) else (torch.rand(m, dtype=dtype, device="cuda") * 2 - 1)
|
||||
|
||||
ref_out = torch._scaled_mm(
|
||||
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
|
||||
)
|
||||
out = ops.wvSplitKQ(
|
||||
B,
|
||||
A,
|
||||
dtype,
|
||||
scale_a,
|
||||
scale_b,
|
||||
get_cu_count(),
|
||||
BIAS,
|
||||
)
|
||||
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count(), BIAS)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
if xnorm:
|
||||
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
|
||||
else:
|
||||
assert torch.allclose(out, ref_out, 0.01)
|
||||
|
||||
@@ -25,10 +25,10 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
bias: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
A.shape[0] == 1
|
||||
and B.shape[1] % 16 == 0
|
||||
A.shape[0] <= 4
|
||||
and B.shape[0] % 16 == 0 # M TODO: needed?
|
||||
and B.shape[1] % 16 == 0 # K
|
||||
and ((bias is None) or (bias.dtype == out_dtype))
|
||||
and A.is_contiguous()
|
||||
):
|
||||
output = ops.wvSplitKQ(
|
||||
B.t(),
|
||||
|
||||
Reference in New Issue
Block a user