Convert wvSplitKQ to 16x16 MFMA in prep for mi4xx. (#34100)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
@@ -1902,7 +1902,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
float sB = *s_B;
|
||||
|
||||
while (m < M) {
|
||||
floatx16 sum[N][YTILE] = {};
|
||||
scalar8 sum[N][YTILE] = {};
|
||||
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
||||
bigType bigA[N][UNRL] = {};
|
||||
bigType bigB[YTILE][UNRL];
|
||||
@@ -1936,7 +1936,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (uint32_t n = 0; n < N; n++) {
|
||||
for (int i = 0; i < A_CHUNK; i += 8) {
|
||||
for (int y = 0; y < YTILE; ++y) {
|
||||
sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
|
||||
0);
|
||||
}
|
||||
@@ -1949,31 +1949,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
float accm0 = sum[n][y][0];
|
||||
float accm16 = sum[n][y][8];
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
|
||||
1); // row_shl1
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
|
||||
1); // row_shl2
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
|
||||
1); // row_shl3
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
|
||||
1); // row_shl8
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
|
||||
1); // row_shl9
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
|
||||
1); // row_shl10
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
|
||||
1); // row_shl11
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
|
||||
accm0 += __shfl(accm0, 36);
|
||||
accm16 += __shfl(accm16, 52);
|
||||
sum[n][y][0] = accm0 + __shfl(accm16, 16);
|
||||
accm0 += __shfl_down(accm0, 20);
|
||||
accm0 += __shfl_down(accm0, 40);
|
||||
sum[n][y][0] = accm0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2064,7 +2048,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
float sB = *s_B;
|
||||
|
||||
while (m < M) {
|
||||
floatx16 sum[N][YTILE] = {};
|
||||
scalar8 sum[N][YTILE] = {};
|
||||
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
||||
bigType bigA[N][UNRL] = {};
|
||||
bigType bigB[YTILE][UNRL];
|
||||
@@ -2100,7 +2084,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (uint32_t n = 0; n < N; n++) {
|
||||
for (int i = 0; i < A_CHUNK; i += 8) {
|
||||
for (int y = 0; y < YTILE; ++y) {
|
||||
sum[n][y] = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bigA[n][k2].l[i / 8], bigB[y][k2].l[i / 8], sum[n][y], 0, 0,
|
||||
0);
|
||||
}
|
||||
@@ -2113,31 +2097,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||
for (int n = 0; n < N; n++) {
|
||||
for (int y = 0; y < YTILE; y++) {
|
||||
float accm0 = sum[n][y][0];
|
||||
float accm16 = sum[n][y][8];
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][1], 0x101, 0xf, 0xf,
|
||||
1); // row_shl1
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
|
||||
1); // row_shl2
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
|
||||
1); // row_shl3
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][11], 0x103, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][4], 0x108, 0xf, 0xf,
|
||||
1); // row_shl8
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][12], 0x108, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][5], 0x109, 0xf, 0xf,
|
||||
1); // row_shl9
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][13], 0x109, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][6], 0x10a, 0xf, 0xf,
|
||||
1); // row_shl10
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][14], 0x10a, 0xf, 0xf, 1);
|
||||
accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][7], 0x10b, 0xf, 0xf,
|
||||
1); // row_shl11
|
||||
accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][15], 0x10b, 0xf, 0xf, 1);
|
||||
accm0 += __shfl(accm0, 36);
|
||||
accm16 += __shfl(accm16, 52);
|
||||
sum[n][y][0] = accm0 + __shfl(accm16, 16);
|
||||
accm0 += __shfl_down(accm0, 20);
|
||||
accm0 += __shfl_down(accm0, 40);
|
||||
sum[n][y][0] = accm0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2242,16 +2210,16 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
|
||||
: nullptr;
|
||||
switch (N_in) {
|
||||
case 1:
|
||||
WVSPLITKQ(12, 2, 2, 2, 2, 1)
|
||||
WVSPLITKQ(16, 2, 2, 2, 2, 1)
|
||||
break;
|
||||
case 2:
|
||||
WVSPLITKQ(12, 2, 2, 2, 2, 2)
|
||||
WVSPLITKQ(16, 2, 2, 2, 2, 2)
|
||||
break;
|
||||
case 3:
|
||||
WVSPLITKQ(8, 2, 2, 1, 1, 3)
|
||||
WVSPLITKQ(16, 2, 2, 2, 2, 3)
|
||||
break;
|
||||
case 4:
|
||||
WVSPLITKQ(4, 2, 2, 1, 1, 4)
|
||||
WVSPLITKQ(16, 2, 2, 2, 2, 4)
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
|
||||
Reference in New Issue
Block a user