Adds padding and perf improvements to wvSplitK_fp8 (#33527)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Author: Hashem Hashemi
Date: 2026-02-05 14:16:02 -08:00
Committed by: GitHub
Parent: 42d5d705f9
Commit: d5c4800112
3 changed files with 169 additions and 229 deletions


@@ -1899,8 +1899,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_sml_(const int K, const int Kp, const int M, const int Bx,
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp, const int M,
const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp,
@@ -1924,9 +1925,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
__shared__ fp8_t s[max_lds_len];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
}
asm volatile("s_waitcnt vmcnt(0)");
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
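For context on the staging loop above: the activation matrix A is copied cooperatively into LDS, using the gfx950 direct global-to-LDS instruction where available and a plain 16-byte vector copy elsewhere. A commented sketch of the same pattern, reusing the kernel's own names (fp8_t, bigType, min__, max_lds_len) and assuming A_CHUNK covers 16 bytes:

```cpp
// Sketch (mirrors the loop above, comments added). On gfx950 the builtin
// issues an asynchronous global->LDS copy, so outstanding vector memory ops
// must be drained (s_waitcnt vmcnt(0)) before any thread reads s[]; the
// barrier then makes the staged data visible to the whole block.
__shared__ fp8_t s[max_lds_len];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
     k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
  // 16 bytes per lane, straight from global memory into LDS.
  __builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
  // Fallback: bounce the 16-byte chunk through registers.
  *((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
}
asm volatile("s_waitcnt vmcnt(0)");
__syncthreads();
```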
@@ -1934,37 +1940,24 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
floatx16 sum[N][YTILE];
float sA = *s_A;
float sB = *s_B;
while (m < M) {
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = {0.f};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
floatx16 sum[N][YTILE] = {};
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
#pragma unroll
for (uint32_t n = 0; n < N; ++n) bigA[n][k2].h8 = {0.f};
#pragma unroll
for (uint32_t y = 0; y < YTILE; ++y) bigB[y][k2].h8 = {0.f};
}
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];
#pragma unroll
for (uint32_t y = 0; y < YTILE; ++y) {
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
}
}
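The fetch above also shows the padding scheme: A and B now carry separate padded strides (Kap and Kbp), and the B addresses are clamped with min__ so every load stays inside the allocated buffer even for tail tiles. A commented restatement, using the kernel's names:

```cpp
// Sketch of the clamped weight fetch above. Kbp is B's padded row stride in
// elements, which may exceed the logical K when the rows were padded.
const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];  // keep the K offset in-bounds
#pragma unroll
for (uint32_t y = 0; y < YTILE; ++y) {
  // Clamp the row index so a tail tile (m + y >= M) reads row M - 1 instead
  // of running past the matrix. Those loads are harmless: they pair with
  // zero-initialized A fragments or are skipped by the
  // "if (y + m >= M) break;" guard before the result is written to C.
  bigB[y][k2].h8 = loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp]));
}
```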
@@ -1975,16 +1968,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (int n = 0; n < N; n++) {
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
if (k >= K) break;
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
@@ -2002,48 +1992,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
1); // row_shl8
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
1); // row_shl9
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
1); // row_shl10
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
1); // row_shl11
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
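The block above replaces the hand-written v_add_f32 row_shl inline asm with __builtin_amdgcn_mov_dpp intrinsics plus two __shfl reads, folding the sixteen MFMA accumulator values of each output element down to one float. As a simpler point of reference (not the committed DPP form, which follows the MFMA accumulator layout), a generic cross-lane sum over a 64-wide CDNA wavefront can be written with __shfl_down:

```cpp
#include <hip/hip_runtime.h>

// Illustration only: collapse one float per lane into a single wave-wide sum.
// The committed kernel achieves the same kind of reduction with mov_dpp
// row_shl steps, which are cheaper because they match how the MFMA
// accumulator distributes rows across lanes.
__device__ float wave_reduce_sum(float v) {
  for (int offset = 32; offset > 0; offset >>= 1) {
    v += __shfl_down(v, offset, 64);  // add the value held 'offset' lanes away
  }
  return v;  // lane 0 now holds the sum over all 64 lanes
}
```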
@@ -2051,19 +2020,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
if (threadIdx.x == 0) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault.
sum[n][y][0] *= sA * sB;
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
sum[n][y][0] += __half2float(biases[n][y]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][y][0] +=
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
sum[n][y][0] += __bfloat162float(biases[n][y]);
}
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]); // * sA * sB);
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
}
}
}
@@ -2074,9 +2047,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
const int M, const int Bx, const int By,
const fp8_t* B, const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS,
scalar_t* C, const float* __restrict__ s_A,
const float* __restrict__ s_B,
@@ -2089,8 +2062,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kp, const int M,
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitKQ_hf_(const int K, const int Kp, const int M, const int Bx,
const int By, const fp8_t* B, const fp8_t* __restrict__ A,
wvSplitKQ_hf_(const int K, const int Kap, const int Kbp, const int M,
const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A, const float* __restrict__ s_B,
const int _WvPrGrp, const int CuCount) {
@@ -2113,9 +2087,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
__shared__ fp8_t s[max_lds_len];
for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK;
k < min__(K * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) {
#if defined(__gfx950__)
__builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0);
#else
*((bigType*)(&s[k])) = *((bigType*)(&A[k]));
#endif
}
asm volatile("s_waitcnt vmcnt(0)");
__syncthreads();
if (threadIdx.y >= _WvPrGrp) return;
@@ -2123,29 +2102,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
floatx16 sum[N][YTILE];
float sA = *s_A;
float sB = *s_B;
while (m < M) {
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++) sum[n][i] = {0};
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
floatx16 sum[N][YTILE] = {};
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
bigType bigA[N][UNRL] = {};
bigType bigB[YTILE][UNRL];
// Fetch the weight matrix from memory!
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
const fp8_t* B_ = &B[(m + 0) * Kp + k_];
const fp8_t* B_ = &B[min__(k_, K - A_CHUNK)];
for (int y = 0; y < YTILE; ++y) {
if (y + m >= M) break; // To avoid mem access fault.
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * Kp])));
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp])));
}
}
@@ -2156,20 +2129,16 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (int n = 0; n < N; n++) {
if (k_ + K * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n])));
if (k_ + Kap * n < max_lds_len)
bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n])));
else
bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n])));
bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n])));
}
}
// Do the matrix multiplication in interleaved manner
#pragma unroll
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
uint32_t k = k1 + k2 * THRDS * A_CHUNK;
uint32_t k_ = k + threadIdx.x * A_CHUNK;
if (k_ >= K) break;
for (uint32_t n = 0; n < N; n++) {
for (int i = 0; i < A_CHUNK; i += 8) {
for (int y = 0; y < YTILE; ++y) {
@@ -2187,48 +2156,27 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
for (int y = 0; y < YTILE; y++) {
float accm0 = sum[n][y][0];
float accm16 = sum[n][y][8];
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][1]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][9]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][2]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][10]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][3]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][11]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][4]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][12]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][5]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:9 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][13]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][6]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:10 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][14]), "v"(accm16));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm0)
: "0"(accm0), "v"(sum[n][y][7]), "v"(accm0));
asm("v_add_f32 %0, %2, %3 row_shl:11 bound_ctrl:0 "
: "=v"(accm16)
: "0"(accm16), "v"(sum[n][y][15]), "v"(accm16));
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
1); // row_shl1
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
1); // row_shl2
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
1); // row_shl3
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
1); // row_shl8
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
1); // row_shl9
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
1); // row_shl10
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
1); // row_shl11
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
accm0 += __shfl(accm0, 36);
accm16 += __shfl(accm16, 52);
sum[n][y][0] = accm0 + __shfl(accm16, 16);
@@ -2236,17 +2184,21 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
}
if (threadIdx.x == 0) {
scalar_t biases[N][YTILE] = {};
if (BIAS)
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx];
}
}
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
if (y + m >= M) break; // To avoid mem access fault.
sum[n][y][0] *= sA * sB;
if constexpr (std::is_same_v<scalar_t, half>) {
if (BIAS)
sum[n][y][0] += __half2float(BIAS[(m + y) % Bx + (n % By) * M]);
sum[n][y][0] += __half2float(biases[n][y]);
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
if (BIAS)
sum[n][y][0] +=
__bfloat162float(BIAS[(m + y) % Bx + (n % By) * M]);
sum[n][y][0] += __bfloat162float(biases[n][y]);
}
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
}
@@ -2259,9 +2211,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
int A_CHUNK, int UNRL, int N>
__global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
const int Bx, const int By, const fp8_t* B,
const fp8_t* __restrict__ A,
__global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
const int M, const int Bx, const int By,
const fp8_t* B, const fp8_t* __restrict__ A,
const scalar_t* __restrict__ BIAS, scalar_t* C,
const float* __restrict__ s_A,
const float* __restrict__ s_B, const int _WvPrGrp,
@@ -2270,17 +2222,18 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M,
}
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
const at::Tensor& scale_a, const at::Tensor& scale_b,
const int64_t CuCount) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
auto M_in = in_a.size(0);
auto K_in = in_a.size(1);
auto N_in = in_b.size(0);
auto Kp_in = in_a.stride(0);
auto M_in = in_b.size(0);
auto K_in = in_b.size(1);
auto N_in = in_a.size(0);
auto Kap_in = in_a.stride(0);
auto Kbp_in = in_b.stride(0);
auto Bx_in =
(in_bias.has_value() && in_bias->numel() > 0)
? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0)
@@ -2300,23 +2253,22 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int max_lds_len = get_lds_size();
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
_N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
__wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, Bx_in, By_in, a_ptr, \
b_ptr, bias_ptr, c_ptr, s_a, s_b, \
__wvPrGrp, CuCount); \
} \
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
{ \
dim3 block(64, _WvPrGrp); \
if ((Kap_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEs, 16)); \
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
} else { \
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEm, 16)); \
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
s_a, s_b, __wvPrGrp, CuCount); \
} \
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] {
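The launcher now passes two padded strides, read straight from the tensors (Kap_in = in_a.stride(0), Kbp_in = in_b.stride(0)), so the kernels can walk rows whose allocation is wider than the logical K. A hypothetical helper, not part of this commit, shows the arithmetic a caller might use when padding fp8 rows to a 256-byte boundary:

```cpp
#include <cstdint>

// Hypothetical helper (assumed, not from this commit): round a row of K
// elements up so every row starts on a 256-byte boundary. For 1-byte fp8
// elements this is the next multiple of 256.
inline int64_t padded_row_stride(int64_t K, int64_t elem_size_bytes,
                                 int64_t align_bytes = 256) {
  const int64_t pad_elems = align_bytes / elem_size_bytes;
  return (K + pad_elems - 1) / pad_elems * pad_elems;
}

// Example: K = 2064 fp8 elements -> stride of 2304 elements, so the tensor
// would report stride(0) == 2304 while size(1) stays 2064.
```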
@@ -2332,16 +2284,16 @@ void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
: nullptr;
switch (N_in) {
case 1:
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 1)
WVSPLITKQ(12, 2, 2, 2, 2, 1)
break;
case 2:
WVSPLITKQ(16, 2, 2, 2, 2, 2, 2, 2)
WVSPLITKQ(12, 2, 2, 2, 2, 2)
break;
case 3:
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 3)
WVSPLITKQ(8, 2, 2, 1, 1, 3)
break;
case 4:
WVSPLITKQ(16, 4, 7, 7, 1, 1, 1, 4)
WVSPLITKQ(4, 2, 2, 1, 1, 4)
break;
default:
throw std::runtime_error(


@@ -73,21 +73,40 @@ NKM_FACTORS_WVSPLITKRC = [
NKM_FACTORS_WVSPLITK_FP8 = [
# FP8-specific cases with K % 16 == 0
(1, 16, 16),
(1, 32, 16 + 16),
(1, 64, 64),
(1, 64, 64 + 16),
(1, 64 + 16, 64),
(1, 64 + 16, 64 + 16),
(4, 64, 64),
(4, 64, 64 + 16),
(4, 64 + 16, 64),
(4, 64 + 16, 64 + 16),
(2, 512, 512),
(3, 512, 512),
(3, 512, 512 + 16),
(4, 512, 512),
(3, 2048, 2048),
(3, 2048, 2048 + 16),
(4, 2048 + 16, 2048),
(4, 2048 + 16, 2048 + 16),
(4, 4096, 4096),
(4, 16400, 2048),
(4, 16400, 2048 + 16),
# Extended FP8 dimensions not covered by WVSPLITK
(1, 14336, 1024),
(2, 24576, 2048),
(4, 32768, 28672),
(4, 32768 * 2, 28672),
(4, 32768 * 2, 28672 + 16),
(4, 32768 * 2 + 16, 28672),
(4, 32768 * 2 + 16, 28672 + 16),
]
SEEDS = [0]
def pad_weights_fp8(weight):
def pad_fp8(weight):
num_pad = 256 // weight.element_size()
import torch.nn.functional as F
@@ -195,72 +214,41 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded", [False, True])
@pytest.mark.parametrize("padded_a", [False, True])
@pytest.mark.parametrize("padded_b", [False, True])
@pytest.mark.parametrize("biased", [False, True])
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed, padded):
def test_rocm_wvsplitk_fp8_kernel(
xnorm, n, k, m, dtype, seed, padded_a, padded_b, biased
):
torch.manual_seed(seed)
A = torch.rand(n, k, device="cuda") - 0.5
B = torch.rand(m, k, device="cuda") - 0.5
xavier = math.sqrt(2 / k) if xnorm else 1 # normalize to avoid large deltas
A = (torch.rand(n, k, device="cuda") * 2 - 1) * xavier
B = (torch.rand(m, k, device="cuda") * 2 - 1) * xavier
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
if padded:
B = pad_weights_fp8(B)
if padded_b:
B = pad_fp8(B)
if padded_a:
A = pad_fp8(A)
ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
)
out = ops.wvSplitKQ(
B,
A,
dtype,
scale_a,
scale_b,
get_cu_count(),
)
assert torch.allclose(out, ref_out, rtol=0.01)
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded", [False, True])
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed, padded):
torch.manual_seed(seed)
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, device="cuda") - 0.5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
if padded:
B = pad_weights_fp8(B)
BIAS = None if (not biased) else (torch.rand(m, dtype=dtype, device="cuda") * 2 - 1)
ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
)
out = ops.wvSplitKQ(
B,
A,
dtype,
scale_a,
scale_b,
get_cu_count(),
BIAS,
)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count(), BIAS)
assert torch.allclose(out, ref_out, rtol=0.01)
if xnorm:
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
else:
assert torch.allclose(out, ref_out, 0.01)


@@ -25,10 +25,10 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
bias: torch.Tensor,
) -> torch.Tensor:
if (
A.shape[0] == 1
and B.shape[1] % 16 == 0
A.shape[0] <= 4
and B.shape[0] % 16 == 0 # M TODO: needed?
and B.shape[1] % 16 == 0 # K
and ((bias is None) or (bias.dtype == out_dtype))
and A.is_contiguous()
):
output = ops.wvSplitKQ(
B.t(),