[ROCm] Enable wvSplitK skinny GEMM kernel for RDNA4/gfx1x decode (#34709)
Signed-off-by: L.B.R. <lbr@mmonad.com> Co-authored-by: L.B.R. <lbr@mmonad.com>
This commit is contained in:
@@ -26,6 +26,16 @@
|
|||||||
#define __HIP__GFX9__
|
#define __HIP__GFX9__
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(__HIPCC__) && \
|
||||||
|
(defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1150__) || \
|
||||||
|
defined(__gfx1151__) || defined(__gfx1200__) || defined(__gfx1201__))
|
||||||
|
#define __HIP__GFX1X__
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(__HIPCC__) && (defined(__gfx1200__) || defined(__gfx1201__))
|
||||||
|
#define __HIP__GFX12__
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
|
#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
|
||||||
#define __HIP__MI3XX__
|
#define __HIP__MI3XX__
|
||||||
#endif
|
#endif
|
||||||
@@ -37,15 +47,31 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
int get_lds_size() {
|
int get_lds_size() {
|
||||||
static bool is_cached = false;
|
static const int result = [] {
|
||||||
static int result;
|
const auto* dprops = at::cuda::getCurrentDeviceProperties();
|
||||||
if (is_cached == false) {
|
const std::string device_arch = dprops->gcnArchName;
|
||||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
return device_arch.find("gfx95") == std::string::npos ? 64 * 1024
|
||||||
std::string device_arch = dprops->gcnArchName;
|
: 160 * 1024;
|
||||||
size_t substring = device_arch.find("gfx95");
|
}();
|
||||||
result = (substring == std::string::npos ? 64 * 1024 : 160 * 1024);
|
return result;
|
||||||
is_cached = true;
|
}
|
||||||
}
|
|
||||||
|
bool on_gfx1x() {
|
||||||
|
static const bool result = [] {
|
||||||
|
const auto* dprops = at::cuda::getCurrentDeviceProperties();
|
||||||
|
const std::string device_arch = dprops->gcnArchName;
|
||||||
|
return device_arch.find("gfx11") != std::string::npos ||
|
||||||
|
device_arch.find("gfx12") != std::string::npos;
|
||||||
|
}();
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool on_gfx12() {
|
||||||
|
static const bool result = [] {
|
||||||
|
const auto* dprops = at::cuda::getCurrentDeviceProperties();
|
||||||
|
const std::string device_arch = dprops->gcnArchName;
|
||||||
|
return device_arch.find("gfx12") != std::string::npos;
|
||||||
|
}();
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -286,21 +312,35 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
|
|||||||
return out_c;
|
return out_c;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DOT2C(V0, V2, V3) \
|
#if defined(__HIP__GFX9__) && !defined(__HIP__GFX1X__)
|
||||||
if constexpr (std::is_same_v<scalar_t, half>) { \
|
#define DOT2C(V0, V2, V3) \
|
||||||
asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(V0) : "0"(V0), "v"(V2), "v"(V3)); \
|
if constexpr (std::is_same_v<scalar_t, half>) { \
|
||||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { \
|
asm("v_dot2c_f32_f16 %0, %2, %3" \
|
||||||
float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \
|
: "=v"(V0) \
|
||||||
__bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \
|
: "0"(V0), "v"(V2), "v"(V3)); \
|
||||||
V0 += (s.x + s.y); \
|
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { \
|
||||||
}
|
float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \
|
||||||
|
__bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \
|
||||||
|
V0 += (s.x + s.y); \
|
||||||
|
}
|
||||||
|
#elif defined(__HIP__GFX1X__)
|
||||||
|
// gfx1x: v_dot2_f32_f16 (VOP3-P, dot10-insts, available on gfx11+gfx12)
|
||||||
|
#define DOT2C(V0, V2, V3) \
|
||||||
|
if constexpr (std::is_same_v<scalar_t, half>) { \
|
||||||
|
asm("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(V0) : "v"(V2), "v"(V3)); \
|
||||||
|
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) { \
|
||||||
|
float2 s = __bfloat1622float2(*((__hip_bfloat162*)(&(V2)))) * \
|
||||||
|
__bfloat1622float2(*((__hip_bfloat162*)(&(V3)))); \
|
||||||
|
V0 += (s.x + s.y); \
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// To avoid LLVM silently upcasting to double
|
// To avoid LLVM silently upcasting to double
|
||||||
__device__ inline unsigned int min__(uint32_t a, uint32_t b) {
|
__device__ inline unsigned int min__(uint32_t a, uint32_t b) {
|
||||||
return min(a, b);
|
return min(a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
|
#if defined(__HIP__GFX9__) || defined(__HIP__GFX1X__)
|
||||||
// This version targets cases where A[] fits LDS capacity
|
// This version targets cases where A[] fits LDS capacity
|
||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N>
|
int UNRL, int N>
|
||||||
@@ -442,14 +482,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
1); // row_shr2
|
1); // row_shr2
|
||||||
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
|
||||||
1); // row_shr1
|
1); // row_shr1
|
||||||
|
#if defined(__HIP__GFX9__)
|
||||||
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
|
||||||
1); // ROW_BCAST15
|
1); // ROW_BCAST15
|
||||||
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
|
||||||
1); // ROW_BCAST31
|
1); // ROW_BCAST31
|
||||||
|
#else
|
||||||
|
sum[n][y] += __shfl_xor(sum[n][y], 16);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (threadIdx.x == 63) {
|
if (threadIdx.x == (THRDS - 1)) {
|
||||||
scalar_t biases[N][YTILE] = {};
|
scalar_t biases[N][YTILE] = {};
|
||||||
if (BIAS)
|
if (BIAS)
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
@@ -469,9 +513,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#ifdef __HIP__GFX9__
|
||||||
|
#pragma unroll
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
/*float accm1 = 0;
|
/*float accm1 = 0;
|
||||||
for (int i=0; i<64; i++)
|
for (int i=0; i<64; i++)
|
||||||
@@ -498,7 +543,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
sum4[n][y][0] = accm;
|
sum4[n][y][0] = accm;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (threadIdx.x == 63) {
|
if (threadIdx.x == (THRDS - 1)) {
|
||||||
scalar_t biases[N][YTILE] = {};
|
scalar_t biases[N][YTILE] = {};
|
||||||
if (BIAS)
|
if (BIAS)
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
@@ -513,11 +558,12 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif // __HIP__GFX9__ (MFMA path)
|
||||||
}
|
}
|
||||||
m += CuCount * _WvPrGrp * YTILE;
|
m += CuCount * _WvPrGrp * YTILE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
#else
|
||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N>
|
int UNRL, int N>
|
||||||
__global__ void wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap,
|
__global__ void wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap,
|
||||||
@@ -528,9 +574,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap,
|
|||||||
const int _WvPrGrp, const int CuCount) {
|
const int _WvPrGrp, const int CuCount) {
|
||||||
UNREACHABLE_CODE
|
UNREACHABLE_CODE
|
||||||
}
|
}
|
||||||
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
#endif
|
||||||
|
|
||||||
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
|
#if defined(__HIP__GFX9__) || defined(__HIP__GFX1X__)
|
||||||
// This version targets cases where A[] marginally exceeds LDS capacity
|
// This version targets cases where A[] marginally exceeds LDS capacity
|
||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N>
|
int UNRL, int N>
|
||||||
@@ -657,14 +703,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
1); // row_shr2
|
1); // row_shr2
|
||||||
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
|
||||||
1); // row_shr1
|
1); // row_shr1
|
||||||
|
#if defined(__HIP__GFX9__)
|
||||||
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
|
||||||
1); // ROW_BCAST15
|
1); // ROW_BCAST15
|
||||||
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
|
||||||
1); // ROW_BCAST31
|
1); // ROW_BCAST31
|
||||||
|
#else
|
||||||
|
sum[n][y] += __shfl_xor(sum[n][y], 16);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (threadIdx.x == 63) {
|
if (threadIdx.x == (THRDS - 1)) {
|
||||||
scalar_t biases[N][YTILE] = {};
|
scalar_t biases[N][YTILE] = {};
|
||||||
if (BIAS)
|
if (BIAS)
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
@@ -686,9 +736,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#ifdef __HIP__GFX9__
|
||||||
|
#pragma unroll
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
// float accm1 = 0;
|
// float accm1 = 0;
|
||||||
// for (int i=0; i<64; i++)
|
// for (int i=0; i<64; i++)
|
||||||
@@ -713,7 +764,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
sum4[n][y][0] = accm;
|
sum4[n][y][0] = accm;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (threadIdx.x == 63) {
|
if (threadIdx.x == (THRDS - 1)) {
|
||||||
scalar_t biases[N][YTILE] = {};
|
scalar_t biases[N][YTILE] = {};
|
||||||
if (BIAS)
|
if (BIAS)
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
@@ -730,6 +781,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif // __HIP__GFX9__ (MFMA path)
|
||||||
}
|
}
|
||||||
|
|
||||||
m += CuCount * _WvPrGrp * YTILE;
|
m += CuCount * _WvPrGrp * YTILE;
|
||||||
@@ -746,7 +798,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
#else
|
||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N>
|
int UNRL, int N>
|
||||||
__global__ void wvSplitK_hf_(const int K, const int Kbp, const int Kap,
|
__global__ void wvSplitK_hf_(const int K, const int Kbp, const int Kap,
|
||||||
@@ -756,9 +808,9 @@ __global__ void wvSplitK_hf_(const int K, const int Kbp, const int Kap,
|
|||||||
const int _WvPrGrp, const int CuCount) {
|
const int _WvPrGrp, const int CuCount) {
|
||||||
UNREACHABLE_CODE
|
UNREACHABLE_CODE
|
||||||
}
|
}
|
||||||
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
#endif
|
||||||
|
|
||||||
#if defined(__HIP__GFX9__) // TODO: Add NAVI support
|
#if defined(__HIP__GFX9__) || defined(__HIP__GFX1X__)
|
||||||
// This version targets big A[] cases, where it is much larger than LDS capacity
|
// This version targets big A[] cases, where it is much larger than LDS capacity
|
||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N>
|
int UNRL, int N>
|
||||||
@@ -1004,14 +1056,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
1); // row_shr2
|
1); // row_shr2
|
||||||
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf,
|
||||||
1); // row_shr1
|
1); // row_shr1
|
||||||
|
#if defined(__HIP__GFX9__)
|
||||||
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf,
|
||||||
1); // ROW_BCAST15
|
1); // ROW_BCAST15
|
||||||
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
|
sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf,
|
||||||
1); // ROW_BCAST31
|
1); // ROW_BCAST31
|
||||||
|
#else
|
||||||
|
sum[n][y] += __shfl_xor(sum[n][y], 16);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (threadIdx.x == 63) {
|
if (threadIdx.x == (THRDS - 1)) {
|
||||||
scalar_t biases[N][YTILE] = {};
|
scalar_t biases[N][YTILE] = {};
|
||||||
if (BIAS)
|
if (BIAS)
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
@@ -1033,9 +1089,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#ifdef __HIP__GFX9__
|
||||||
|
#pragma unroll
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
float accm = sum4[n][y][0];
|
float accm = sum4[n][y][0];
|
||||||
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
|
accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf,
|
||||||
@@ -1057,7 +1114,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
sum4[n][y][0] = accm;
|
sum4[n][y][0] = accm;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (threadIdx.x == 63) {
|
if (threadIdx.x == (THRDS - 1)) {
|
||||||
scalar_t biases[N][YTILE] = {};
|
scalar_t biases[N][YTILE] = {};
|
||||||
if (BIAS)
|
if (BIAS)
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
@@ -1074,6 +1131,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif // __HIP__GFX9__ (MFMA path)
|
||||||
}
|
}
|
||||||
|
|
||||||
m += CuCount * _WvPrGrp * YTILE;
|
m += CuCount * _WvPrGrp * YTILE;
|
||||||
@@ -1090,7 +1148,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
#else
|
||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N>
|
int UNRL, int N>
|
||||||
__global__ void wvSplitK_hf_big_(const int K, const int Kbp, const int Kap,
|
__global__ void wvSplitK_hf_big_(const int K, const int Kbp, const int Kap,
|
||||||
@@ -1101,7 +1159,7 @@ __global__ void wvSplitK_hf_big_(const int K, const int Kbp, const int Kap,
|
|||||||
const int _WvPrGrp, const int CuCount) {
|
const int _WvPrGrp, const int CuCount) {
|
||||||
UNREACHABLE_CODE
|
UNREACHABLE_CODE
|
||||||
}
|
}
|
||||||
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
#endif
|
||||||
|
|
||||||
// Find the min val of div2 that doesn't increase N/(div1*div2)
|
// Find the min val of div2 that doesn't increase N/(div1*div2)
|
||||||
int mindiv(int N, int div1, int div2) {
|
int mindiv(int N, int div1, int div2) {
|
||||||
@@ -1148,40 +1206,40 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
|||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
const int max_lds_len = get_lds_size() / 2;
|
const int max_lds_len = get_lds_size() / 2;
|
||||||
|
|
||||||
#define WVSPLITK(_YTILE, _UNRL, _N) \
|
#define WVSPLITK_CFG(_THRDS, _WVPRGRP, _YTILE, _UNRL, _N) \
|
||||||
{ \
|
{ \
|
||||||
dim3 block(64, 16); \
|
dim3 block(_THRDS, _WVPRGRP); \
|
||||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
|
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, _WVPRGRP); \
|
||||||
if ((Kbp_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
|
if ((Kbp_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
|
||||||
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
wvSplitK_hf_sml_<fptype, _THRDS, _YTILE, _WVPRGRP, 8, _UNRL, _N> \
|
||||||
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
||||||
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
|
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
|
||||||
CuCount); \
|
CuCount); \
|
||||||
else if (Kbp_in * N_in <= max_lds_len * 1.2) \
|
else if (Kbp_in * N_in <= max_lds_len * 1.2) \
|
||||||
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
wvSplitK_hf_<fptype, _THRDS, _YTILE, _WVPRGRP, 8, _UNRL, _N> \
|
||||||
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
||||||
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
|
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
|
||||||
CuCount); \
|
CuCount); \
|
||||||
else \
|
else \
|
||||||
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
wvSplitK_hf_big_<fptype, _THRDS, _YTILE, _WVPRGRP, 8, _UNRL, _N> \
|
||||||
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
||||||
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
|
By_in, af4, bf4, biasf4, c, __wvPrGrp, \
|
||||||
CuCount); \
|
CuCount); \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define WVSPLIT_TILE(_sYT, __N) \
|
#define WVSPLIT_TILE_CFG(_THRDS, _WVPRGRP, _sYT, __N) \
|
||||||
{ \
|
{ \
|
||||||
bool fit_lds = (Kbp_in * N_in <= max_lds_len); \
|
bool fit_lds = (Kbp_in * N_in <= max_lds_len); \
|
||||||
if (_sYT <= 1) \
|
if (_sYT <= 1) \
|
||||||
WVSPLITK(1, 4, __N) \
|
WVSPLITK_CFG(_THRDS, _WVPRGRP, 1, 4, __N) \
|
||||||
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
|
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
|
||||||
WVSPLITK(2, 2, __N) \
|
WVSPLITK_CFG(_THRDS, _WVPRGRP, 2, 2, __N) \
|
||||||
else if (_sYT <= 4 * 3) \
|
else if (_sYT <= 4 * 3) \
|
||||||
WVSPLITK(3, 2, __N) \
|
WVSPLITK_CFG(_THRDS, _WVPRGRP, 3, 2, __N) \
|
||||||
else if (__N == 4) \
|
else if (__N == 4) \
|
||||||
WVSPLITK(4, 1, __N) \
|
WVSPLITK_CFG(_THRDS, _WVPRGRP, 4, 1, __N) \
|
||||||
else \
|
else \
|
||||||
WVSPLITK(4, 2, __N) \
|
WVSPLITK_CFG(_THRDS, _WVPRGRP, 4, 2, __N) \
|
||||||
}
|
}
|
||||||
|
|
||||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
|
||||||
@@ -1198,18 +1256,31 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
|||||||
// then cut the active waves to balance their distribution...
|
// then cut the active waves to balance their distribution...
|
||||||
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
|
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
|
||||||
|
|
||||||
|
const bool use_wave32 = on_gfx1x();
|
||||||
switch (N_in) {
|
switch (N_in) {
|
||||||
case 1:
|
case 1:
|
||||||
WVSPLIT_TILE(sYT, 1)
|
if (use_wave32)
|
||||||
|
WVSPLIT_TILE_CFG(32, 16, sYT, 1)
|
||||||
|
else
|
||||||
|
WVSPLIT_TILE_CFG(64, 16, sYT, 1)
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
WVSPLIT_TILE(sYT, 2)
|
if (use_wave32)
|
||||||
|
WVSPLIT_TILE_CFG(32, 16, sYT, 2)
|
||||||
|
else
|
||||||
|
WVSPLIT_TILE_CFG(64, 16, sYT, 2)
|
||||||
break;
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
WVSPLIT_TILE(sYT, 3)
|
if (use_wave32)
|
||||||
|
WVSPLIT_TILE_CFG(32, 16, sYT, 3)
|
||||||
|
else
|
||||||
|
WVSPLIT_TILE_CFG(64, 16, sYT, 3)
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
WVSPLIT_TILE(sYT, 4)
|
if (use_wave32)
|
||||||
|
WVSPLIT_TILE_CFG(32, 16, sYT, 4)
|
||||||
|
else
|
||||||
|
WVSPLIT_TILE_CFG(64, 16, sYT, 4)
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
@@ -1653,7 +1724,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else // !defined(__HIP__GFX9__) TODO: Add NAVI support
|
#else
|
||||||
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
template <typename scalar_t, int THRDS, int YTILE, int WvPrGrp, int A_CHUNK,
|
||||||
int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC>
|
int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC>
|
||||||
__global__ void wvSplitKrc_(const int actlN, const int K, const int Kap,
|
__global__ void wvSplitKrc_(const int actlN, const int K, const int Kap,
|
||||||
@@ -1688,6 +1759,8 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
|||||||
TORCH_CHECK(in_a.dtype() == torch::kFloat16 ||
|
TORCH_CHECK(in_a.dtype() == torch::kFloat16 ||
|
||||||
in_a.dtype() == torch::kBFloat16);
|
in_a.dtype() == torch::kBFloat16);
|
||||||
|
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
|
||||||
|
|
||||||
auto out_c = torch::empty(
|
auto out_c = torch::empty(
|
||||||
{N_in, M_in},
|
{N_in, M_in},
|
||||||
torch::TensorOptions().dtype(in_a.dtype()).device(in_a.device()));
|
torch::TensorOptions().dtype(in_a.dtype()).device(in_a.device()));
|
||||||
@@ -1696,7 +1769,6 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
|||||||
|
|
||||||
dim3 grid(CuCount);
|
dim3 grid(CuCount);
|
||||||
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a));
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
// const int max_lds_len = get_lds_size() / 2;
|
// const int max_lds_len = get_lds_size() / 2;
|
||||||
|
|
||||||
@@ -1773,7 +1845,7 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
|
|||||||
return out_c;
|
return out_c;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
|
#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
|
||||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||||
int A_CHUNK, int UNRL, int N>
|
int A_CHUNK, int UNRL, int N>
|
||||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||||
@@ -1817,12 +1889,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
|
|
||||||
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
|
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
|
||||||
|
|
||||||
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
|
|
||||||
float sA = *s_A;
|
float sA = *s_A;
|
||||||
float sB = *s_B;
|
float sB = *s_B;
|
||||||
|
|
||||||
while (m < M) {
|
while (m < M) {
|
||||||
|
#ifdef __HIP__GFX12__
|
||||||
|
// gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8
|
||||||
|
float sum[N][YTILE] = {};
|
||||||
|
#else
|
||||||
|
// gfx9: MFMA accumulation
|
||||||
scalar8 sum[N][YTILE] = {};
|
scalar8 sum[N][YTILE] = {};
|
||||||
|
#endif
|
||||||
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
||||||
bigType bigA[N][UNRL] = {};
|
bigType bigA[N][UNRL] = {};
|
||||||
bigType bigB[YTILE][UNRL];
|
bigType bigB[YTILE][UNRL];
|
||||||
@@ -1854,6 +1931,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||||
for (uint32_t n = 0; n < N; n++) {
|
for (uint32_t n = 0; n < N; n++) {
|
||||||
|
#ifdef __HIP__GFX12__
|
||||||
|
// gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4)
|
||||||
|
for (int y = 0; y < YTILE; ++y) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < A_CHUNK / 4; i++) {
|
||||||
|
sum[n][y] = __builtin_amdgcn_dot4_f32_fp8_fp8(
|
||||||
|
bigA[n][k2].i[i], bigB[y][k2].i[i], sum[n][y]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// gfx9: MFMA path
|
||||||
for (int i = 0; i < A_CHUNK; i += 8) {
|
for (int i = 0; i < A_CHUNK; i += 8) {
|
||||||
for (int y = 0; y < YTILE; ++y) {
|
for (int y = 0; y < YTILE; ++y) {
|
||||||
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||||
@@ -1861,11 +1949,33 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
0);
|
0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Final reduction
|
// Final reduction
|
||||||
|
#ifdef __HIP__GFX12__
|
||||||
|
// gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
for (int y = 0; y < YTILE; y++) {
|
||||||
|
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
|
||||||
|
: "=v"(sum[n][y])
|
||||||
|
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
||||||
|
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
|
||||||
|
: "=v"(sum[n][y])
|
||||||
|
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
||||||
|
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
|
||||||
|
: "=v"(sum[n][y])
|
||||||
|
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
||||||
|
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:1 bound_ctrl:0 "
|
||||||
|
: "=v"(sum[n][y])
|
||||||
|
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
||||||
|
sum[n][y] += __shfl_xor(sum[n][y], 16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// gfx9 MFMA reduction
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
float accm0 = sum[n][y][0];
|
float accm0 = sum[n][y][0];
|
||||||
@@ -1880,8 +1990,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
sum[n][y][0] = accm0;
|
sum[n][y][0] = accm0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
const bool writeback_lane =
|
||||||
|
#ifdef __HIP__GFX12__
|
||||||
|
threadIdx.x == (THRDS - 1);
|
||||||
|
#else
|
||||||
|
threadIdx.x == 0;
|
||||||
|
#endif
|
||||||
|
if (writeback_lane) {
|
||||||
scalar_t biases[N][YTILE] = {};
|
scalar_t biases[N][YTILE] = {};
|
||||||
if (BIAS)
|
if (BIAS)
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
@@ -1892,13 +2009,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
if (y + m >= M) break; // To avoid mem access fault.
|
if (y + m >= M) break; // To avoid mem access fault.
|
||||||
sum[n][y][0] *= sA * sB;
|
#ifdef __HIP__GFX12__
|
||||||
|
float result = sum[n][y] * sA * sB;
|
||||||
|
#else
|
||||||
|
float result = sum[n][y][0] * sA * sB;
|
||||||
|
#endif
|
||||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||||
sum[n][y][0] += __half2float(biases[n][y]);
|
result += __half2float(biases[n][y]);
|
||||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||||
sum[n][y][0] += __bfloat162float(biases[n][y]);
|
result += __bfloat162float(biases[n][y]);
|
||||||
}
|
}
|
||||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
|
C[m + y + n * M] = __float2s<scalar_t>(result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1906,7 +2027,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
m += CuCount * _WvPrGrp * YTILE;
|
m += CuCount * _WvPrGrp * YTILE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
|
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__)
|
||||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||||
int A_CHUNK, int UNRL, int N>
|
int A_CHUNK, int UNRL, int N>
|
||||||
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
|
__global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
|
||||||
@@ -1918,9 +2039,9 @@ __global__ void wvSplitKQ_hf_sml_(const int K, const int Kap, const int Kbp,
|
|||||||
const int _WvPrGrp, const int CuCount) {
|
const int _WvPrGrp, const int CuCount) {
|
||||||
UNREACHABLE_CODE
|
UNREACHABLE_CODE
|
||||||
}
|
}
|
||||||
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
|
#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
|
||||||
|
|
||||||
#if defined(__HIP__MI3XX__) // TODO: Add NAVI support
|
#if defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
|
||||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||||
int A_CHUNK, int UNRL, int N>
|
int A_CHUNK, int UNRL, int N>
|
||||||
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
__global__ void __launch_bounds__(WvPrGrp* THRDS)
|
||||||
@@ -1963,12 +2084,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
|
|
||||||
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
|
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
|
||||||
|
|
||||||
using floatx16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
|
|
||||||
float sA = *s_A;
|
float sA = *s_A;
|
||||||
float sB = *s_B;
|
float sB = *s_B;
|
||||||
|
|
||||||
while (m < M) {
|
while (m < M) {
|
||||||
|
#ifdef __HIP__GFX12__
|
||||||
|
// gfx12: per-lane scalar accumulation via v_dot4_f32_fp8_fp8
|
||||||
|
float sum[N][YTILE] = {};
|
||||||
|
#else
|
||||||
|
// gfx9: MFMA accumulation
|
||||||
scalar8 sum[N][YTILE] = {};
|
scalar8 sum[N][YTILE] = {};
|
||||||
|
#endif
|
||||||
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) {
|
||||||
bigType bigA[N][UNRL] = {};
|
bigType bigA[N][UNRL] = {};
|
||||||
bigType bigB[YTILE][UNRL];
|
bigType bigB[YTILE][UNRL];
|
||||||
@@ -2002,6 +2128,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
for (uint32_t k2 = 0; k2 < UNRL; k2++) {
|
||||||
for (uint32_t n = 0; n < N; n++) {
|
for (uint32_t n = 0; n < N; n++) {
|
||||||
|
#ifdef __HIP__GFX12__
|
||||||
|
// gfx12: 4 x dot4 per A_CHUNK=16 bytes (4 FP8 per dot4)
|
||||||
|
for (int y = 0; y < YTILE; ++y) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < A_CHUNK / 4; i++) {
|
||||||
|
sum[n][y] = __builtin_amdgcn_dot4_f32_fp8_fp8(
|
||||||
|
bigA[n][k2].i[i], bigB[y][k2].i[i], sum[n][y]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// gfx9: MFMA path
|
||||||
for (int i = 0; i < A_CHUNK; i += 8) {
|
for (int i = 0; i < A_CHUNK; i += 8) {
|
||||||
for (int y = 0; y < YTILE; ++y) {
|
for (int y = 0; y < YTILE; ++y) {
|
||||||
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
sum[n][y] = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||||
@@ -2009,11 +2146,33 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
0);
|
0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Final reduction
|
// Final reduction
|
||||||
|
#ifdef __HIP__GFX12__
|
||||||
|
// gfx12 wave32: DPP row_shr within 16-lane rows + cross-row shuffle
|
||||||
|
for (int n = 0; n < N; n++) {
|
||||||
|
for (int y = 0; y < YTILE; y++) {
|
||||||
|
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
|
||||||
|
: "=v"(sum[n][y])
|
||||||
|
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
||||||
|
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
|
||||||
|
: "=v"(sum[n][y])
|
||||||
|
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
||||||
|
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
|
||||||
|
: "=v"(sum[n][y])
|
||||||
|
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
||||||
|
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:1 bound_ctrl:0 "
|
||||||
|
: "=v"(sum[n][y])
|
||||||
|
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
|
||||||
|
sum[n][y] += __shfl_xor(sum[n][y], 16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// gfx9 MFMA reduction
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
float accm0 = sum[n][y][0];
|
float accm0 = sum[n][y][0];
|
||||||
@@ -2028,8 +2187,15 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
sum[n][y][0] = accm0;
|
sum[n][y][0] = accm0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
const bool writeback_lane =
|
||||||
|
#ifdef __HIP__GFX12__
|
||||||
|
threadIdx.x == (THRDS - 1);
|
||||||
|
#else
|
||||||
|
threadIdx.x == 0;
|
||||||
|
#endif
|
||||||
|
if (writeback_lane) {
|
||||||
scalar_t biases[N][YTILE] = {};
|
scalar_t biases[N][YTILE] = {};
|
||||||
if (BIAS)
|
if (BIAS)
|
||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
@@ -2040,13 +2206,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
for (int n = 0; n < N; n++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int y = 0; y < YTILE; y++) {
|
for (int y = 0; y < YTILE; y++) {
|
||||||
if (y + m >= M) break; // To avoid mem access fault.
|
if (y + m >= M) break; // To avoid mem access fault.
|
||||||
sum[n][y][0] *= sA * sB;
|
#ifdef __HIP__GFX12__
|
||||||
|
float result = sum[n][y] * sA * sB;
|
||||||
|
#else
|
||||||
|
float result = sum[n][y][0] * sA * sB;
|
||||||
|
#endif
|
||||||
if constexpr (std::is_same_v<scalar_t, half>) {
|
if constexpr (std::is_same_v<scalar_t, half>) {
|
||||||
sum[n][y][0] += __half2float(biases[n][y]);
|
result += __half2float(biases[n][y]);
|
||||||
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
} else if constexpr (std::is_same_v<scalar_t, __hip_bfloat16>) {
|
||||||
sum[n][y][0] += __bfloat162float(biases[n][y]);
|
result += __bfloat162float(biases[n][y]);
|
||||||
}
|
}
|
||||||
C[m + y + n * M] = __float2s<scalar_t>(sum[n][y][0]);
|
C[m + y + n * M] = __float2s<scalar_t>(result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2054,7 +2224,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
|
|||||||
m += CuCount * _WvPrGrp * YTILE;
|
m += CuCount * _WvPrGrp * YTILE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else // !defined(__HIP__MI3XX__) TODO: Add NAVI support
|
#else // !defined(__HIP__MI3XX__) && !defined(__HIP__GFX12__)
|
||||||
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
template <typename scalar_t, typename fp8_t, int THRDS, int YTILE, int WvPrGrp,
|
||||||
int A_CHUNK, int UNRL, int N>
|
int A_CHUNK, int UNRL, int N>
|
||||||
__global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
|
__global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
|
||||||
@@ -2066,7 +2236,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kap, const int Kbp,
|
|||||||
const int CuCount) {
|
const int CuCount) {
|
||||||
UNREACHABLE_CODE
|
UNREACHABLE_CODE
|
||||||
}
|
}
|
||||||
#endif // defined(__HIP__MI3XX__) TODO: Add NAVI support
|
#endif // defined(__HIP__MI3XX__) || defined(__HIP__GFX12__)
|
||||||
|
|
||||||
void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
|
void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
|
||||||
const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
|
const std::optional<at::Tensor>& in_bias, at::Tensor& out_c,
|
||||||
@@ -2099,24 +2269,30 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
|
|||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
const int max_lds_len = get_lds_size();
|
const int max_lds_len = get_lds_size();
|
||||||
|
|
||||||
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
|
#define WVSPLITKQ_IMPL(_THRDS, _WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
|
||||||
{ \
|
{ \
|
||||||
dim3 block(64, _WvPrGrp); \
|
dim3 block(_THRDS, _WvPrGrp); \
|
||||||
if ((Kap_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
if ((Kap_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
||||||
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEs, 16)); \
|
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEs, 16)); \
|
||||||
wvSplitKQ_hf_sml_<fptype, fp8_t, 64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
|
wvSplitKQ_hf_sml_<fptype, fp8_t, _THRDS, _YTILEs, _WvPrGrp, 16, _UNRLs, \
|
||||||
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
_N><<<grid, block, 0, stream>>>( \
|
||||||
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
|
K_in, Kap_in, Kbp_in, M_in, Bx_in, By_in, b_ptr, a_ptr, bias_ptr, \
|
||||||
s_a, s_b, __wvPrGrp, CuCount); \
|
c_ptr, s_a, s_b, __wvPrGrp, CuCount); \
|
||||||
} else { \
|
} else { \
|
||||||
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEm, 16)); \
|
int __wvPrGrp = min(_WvPrGrp, mindiv(M_in, CuCount * _YTILEm, 16)); \
|
||||||
wvSplitKQ_hf_<fptype, fp8_t, 64, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
|
wvSplitKQ_hf_<fptype, fp8_t, _THRDS, _YTILEm, _WvPrGrp, 16, _UNRLm, _N> \
|
||||||
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
<<<grid, block, 0, stream>>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \
|
||||||
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
|
By_in, b_ptr, a_ptr, bias_ptr, c_ptr, \
|
||||||
s_a, s_b, __wvPrGrp, CuCount); \
|
s_a, s_b, __wvPrGrp, CuCount); \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define WVSPLITKQ(_WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
|
||||||
|
if (on_gfx12()) \
|
||||||
|
WVSPLITKQ_IMPL(32, _WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N) \
|
||||||
|
else \
|
||||||
|
WVSPLITKQ_IMPL(64, _WvPrGrp, _YTILEs, _YTILEm, _UNRLs, _UNRLm, _N)
|
||||||
|
|
||||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] {
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_c.scalar_type(), "wvSplitKQ", [&] {
|
||||||
using fptype = typename scalar<scalar_t>::type;
|
using fptype = typename scalar<scalar_t>::type;
|
||||||
auto c_ptr = reinterpret_cast<fptype*>(out_c.data_ptr());
|
auto c_ptr = reinterpret_cast<fptype*>(out_c.data_ptr());
|
||||||
@@ -2136,10 +2312,10 @@ void wvSplitKQ(const at::Tensor& in_b, const at::Tensor& in_a,
|
|||||||
WVSPLITKQ(16, 2, 2, 2, 2, 2)
|
WVSPLITKQ(16, 2, 2, 2, 2, 2)
|
||||||
break;
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
WVSPLITKQ(16, 2, 2, 2, 2, 3)
|
WVSPLITKQ(16, 2, 2, 1, 1, 3)
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
WVSPLITKQ(16, 2, 2, 2, 2, 4)
|
WVSPLITKQ(16, 2, 2, 1, 1, 4)
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
|||||||
@@ -160,6 +160,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, padded_a, bias_mode
|
|||||||
BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
|
BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1
|
||||||
elif bias_mode == 2:
|
elif bias_mode == 2:
|
||||||
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
|
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1
|
||||||
|
elif bias_mode == 3:
|
||||||
|
BIAS = torch.rand(1, m, dtype=dtype, device="cuda") * 2 - 1
|
||||||
|
|
||||||
ref_out = torch.nn.functional.linear(A, B, BIAS)
|
ref_out = torch.nn.functional.linear(A, B, BIAS)
|
||||||
out = ops.wvSplitKrc(A, B, cu_count, BIAS)
|
out = ops.wvSplitKrc(A, B, cu_count, BIAS)
|
||||||
@@ -224,10 +226,9 @@ def test_rocm_wvsplitk_kernel(
|
|||||||
ref_out = torch.nn.functional.linear(A, B, BIAS)
|
ref_out = torch.nn.functional.linear(A, B, BIAS)
|
||||||
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
|
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
|
||||||
|
|
||||||
if xnorm:
|
# Accumulation error in fp16 GEMM scales with sqrt(K)
|
||||||
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-8)
|
atol = torch.finfo(dtype).eps * math.sqrt(k)
|
||||||
else:
|
torch.testing.assert_close(out, ref_out, atol=atol, rtol=1e-2)
|
||||||
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-2)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("xnorm", [False, True])
|
@pytest.mark.parametrize("xnorm", [False, True])
|
||||||
|
|||||||
89
tests/model_executor/layers/test_rocm_unquantized_gemm.py
Normal file
89
tests/model_executor/layers/test_rocm_unquantized_gemm.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
if current_platform.is_cuda():
|
||||||
|
pytest.skip(
|
||||||
|
"ROCm skinny GEMM tests are not supported on CUDA.",
|
||||||
|
allow_module_level=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm.model_executor.layers import utils
|
||||||
|
|
||||||
|
|
||||||
|
def test_rocm_unquantized_gemm_gfx1x_wvsplitk_path(monkeypatch):
    """With gfx1x reported and one input row, ``wvSplitK`` is used, not ``LLMM1``.

    All platform probes and custom ops are monkeypatched, so only the
    dispatch logic of ``rocm_unquantized_gemm_impl`` is exercised.
    """
    x = torch.randn(1, 64, dtype=torch.float16)
    weight = torch.randn(128, 64, dtype=torch.float16)

    # Turn off the aiter/triton shortcut and report a gfx1x-only device.
    monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
    monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: True)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: False)
    monkeypatch.setattr(utils, "get_cu_count", lambda: 120)

    # The op stand-ins compute a plain matmul so the output stays comparable
    # with the reference while the mocks record which op was selected.
    wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
    monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)
    llmm1_mock = MagicMock(side_effect=lambda w, x_view, _: x_view @ w.t())
    monkeypatch.setattr(utils.ops, "LLMM1", llmm1_mock)

    out = utils.rocm_unquantized_gemm_impl(x, weight, None)
    ref = torch.nn.functional.linear(x, weight, None)

    wvsplitk_mock.assert_called_once()
    llmm1_mock.assert_not_called()
    # assert_close reports the mismatching elements on failure, unlike a bare
    # ``assert torch.allclose(...)``.
    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rocm_unquantized_gemm_gfx1x_n_gt_4_falls_back(monkeypatch):
    """With gfx1x reported and five input rows, neither skinny op is invoked.

    The output must still match ``torch.nn.functional.linear``.
    """
    x = torch.randn(5, 64, dtype=torch.float16)
    weight = torch.randn(128, 64, dtype=torch.float16)

    # Turn off the aiter/triton shortcut and report a gfx1x-only device.
    monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
    monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: True)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: False)
    monkeypatch.setattr(utils, "get_cu_count", lambda: 120)

    # Both skinny ops are mocked so that any call to them is recorded.
    wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
    monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)
    llmm1_mock = MagicMock(side_effect=lambda w, x_view, _: x_view @ w.t())
    monkeypatch.setattr(utils.ops, "LLMM1", llmm1_mock)

    out = utils.rocm_unquantized_gemm_impl(x, weight, None)
    ref = torch.nn.functional.linear(x, weight, None)

    wvsplitk_mock.assert_not_called()
    llmm1_mock.assert_not_called()
    # assert_close reports the mismatching elements on failure, unlike a bare
    # ``assert torch.allclose(...)``.
    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rocm_unquantized_gemm_gfx950_wvsplitkrc_path(monkeypatch):
    """With only gfx950 reported, a 16x1024 input uses ``wvSplitKrc``.

    ``wvSplitK`` must not be called for this shape.
    """
    x = torch.randn(16, 1024, dtype=torch.float16)
    weight = torch.randn(256, 1024, dtype=torch.float16)

    # Turn off the aiter/triton shortcut and report gfx950 as the only match.
    monkeypatch.setattr(utils, "use_aiter_triton_gemm", lambda *args: False)
    monkeypatch.setattr(utils.envs, "VLLM_ROCM_USE_SKINNY_GEMM", True)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx1x", lambda: False)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx9", lambda: False)
    monkeypatch.setattr("vllm.platforms.rocm.on_gfx950", lambda: True)
    monkeypatch.setattr(utils, "get_cu_count", lambda: 120)

    # The op stand-ins compute a plain matmul so the output stays comparable
    # with the reference while the mocks record which op was selected.
    wvsplitkrc_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
    monkeypatch.setattr(utils.ops, "wvSplitKrc", wvsplitkrc_mock)
    wvsplitk_mock = MagicMock(side_effect=lambda w, x_view, _, __: x_view @ w.t())
    monkeypatch.setattr(utils.ops, "wvSplitK", wvsplitk_mock)

    out = utils.rocm_unquantized_gemm_impl(x, weight, None)
    ref = torch.nn.functional.linear(x, weight, None)

    wvsplitkrc_mock.assert_called_once()
    wvsplitk_mock.assert_not_called()
    # assert_close reports the mismatching elements on failure, unlike a bare
    # ``assert torch.allclose(...)``.
    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)
|
||||||
@@ -122,7 +122,7 @@ def use_aiter_triton_gemm(n, m, k, dtype):
|
|||||||
def rocm_unquantized_gemm_impl(
|
def rocm_unquantized_gemm_impl(
|
||||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
|
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from vllm.platforms.rocm import on_gfx9, on_gfx950
|
from vllm.platforms.rocm import on_gfx1x, on_gfx9, on_gfx950
|
||||||
|
|
||||||
n = x.numel() // x.size(-1)
|
n = x.numel() // x.size(-1)
|
||||||
m = weight.shape[0]
|
m = weight.shape[0]
|
||||||
@@ -169,12 +169,12 @@ def rocm_unquantized_gemm_impl(
|
|||||||
|
|
||||||
use_skinny = (
|
use_skinny = (
|
||||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||||
and on_gfx9()
|
and (on_gfx9() or on_gfx1x())
|
||||||
and x.dtype in [torch.float16, torch.bfloat16]
|
and x.dtype in [torch.float16, torch.bfloat16]
|
||||||
and k % 8 == 0
|
and k % 8 == 0
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_skinny is not True:
|
if not use_skinny:
|
||||||
return torch.nn.functional.linear(x, weight, bias)
|
return torch.nn.functional.linear(x, weight, bias)
|
||||||
|
|
||||||
x_view = x.reshape(-1, x.size(-1))
|
x_view = x.reshape(-1, x.size(-1))
|
||||||
|
|||||||
Reference in New Issue
Block a user