diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 9e776296f..442b20e41 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -12,6 +12,7 @@ #include "../cuda_compat.h" #include "dispatch_utils.h" #include "quantization/w8a8/fp8/common.cuh" +#include "core/batch_invariant.hpp" // TODO(rasmith): The kernels in this file are susceptible to integer overflow // issues, do not take strides, and are unable to handle PyTorch tensors that @@ -1224,17 +1225,14 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, #if defined(__gfx950__) #define WVSPLITKRC_1KPASS template + int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC> __global__ void __launch_bounds__(WvPrGrp* THRDS) __attribute__((amdgpu_waves_per_eu(1, 1))) - wvSplitKrc_(const int actlN, const int K, const int M, const int Bx, - const int By, const scalar_t* __restrict__ B, - const scalar_t* __restrict__ A, - const scalar_t* __restrict__ BIAS, float* glbl, scalar_t* C, - const int CuCount) { - // Use upper half of glbl buffer for atomic reduce counting - int* cntr = (int*)(&glbl[M * N]); - + wvSplitKrc_(const int actlN, const int K, const int Kap, const int M, + const int Bx, const int By, const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + const scalar_t* __restrict__ BIAS, float* glbl, int* cntr, + scalar_t* C, const int CuCount) { constexpr int NTILE = 16; constexpr int APAD = 1; constexpr int ASTRD = 64; @@ -1425,11 +1423,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff); for (unsigned int n = 0; n < N; n += CHUNKK * sprdN) { __builtin_amdgcn_global_load_lds( - (int*)(&A[min__( - K * actlN - A_CHUNK, - kOffcp + K * (n / CHUNKK + - (N / CHUNKK) * (threadIdx.x / (64 / CHUNKK)) + - (threadIdx.y % sprdN)))]), + (int*)(&A[min__(Kap * actlN - A_CHUNK, + kOffcp + Kap * (n / CHUNKK + + (N / CHUNKK) * (threadIdx.x / + (64 / CHUNKK)) + + (threadIdx.y % sprdN)))]), (int*)(&s[(k + kFitPdd * ((n / CHUNKK) + (threadIdx.y % sprdN)))]), 16, 0, 0); @@ -1533,45 +1531,98 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } + union flt4 { + scalar8 s8; + float2 f2[2]; + float4 f4; + }; if (m + (threadIdx.x % 16) < M) { int my_cntr; int mindx = m + (threadIdx.x % 16); int g_mindx = m * 4 + (threadIdx.x % 64); // coalesced atomic reduction scalar_t biases[N / NTILE / GrpsShrB][4] = {}; // Atomic add the output, read biases - for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) - for (uint32_t j = 0; j < 4; j++) { - // int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + - // (N / GrpsShrB) * (threadIdx.y % GrpsShrB); - // int adr = mindx + M * nindx; - int g_nindx = - j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; - int g_adr = g_mindx + M * g_nindx * 4; - atomicAdd(&glbl[g_adr], sum4[nt][0][j]); + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + int g_nindx = + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; + int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4; + if (DTRMNSTC) { + flt4 flt4_ = {.s8 = sum4[nt][0]}; + __hip_atomic_store((float2*)&glbl[g_adr + M * N * (m0 / Mmod)], + flt4_.f2[0], __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + __hip_atomic_store((float2*)&glbl[g_adr + 2 + M * N * (m0 / Mmod)], + flt4_.f2[1], __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } else { + for (uint32_t j = 0; j < 4; j++) + atomicAdd((&glbl[g_adr + j]), sum4[nt][0][j]); } + } + + __atomic_signal_fence(__ATOMIC_SEQ_CST); + asm volatile("s_waitcnt vmcnt(0)" ::: "memory"); + __atomic_signal_fence(__ATOMIC_SEQ_CST); + int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); int adr_ = mindx + M * nindx_ / 4; - // Update the complete counter my_cntr = atomicAdd(&cntr[adr_], 1); - float vals[N / NTILE / GrpsShrB][4] = {}; + + // make sure LDS is free for write out staging + if (DTRMNSTC) __syncthreads(); + + // Update the complete counter + flt4 vals[N / NTILE / GrpsShrB] = {}; // If we're the last k-shard, read back the value and convert... if (my_cntr + 1 == k_rnd) { - if (BIAS) - for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { - for (uint32_t j = 0; j < 4; j++) { - int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + - (N / GrpsShrB) * (threadIdx.y % GrpsShrB); - biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx]; + cntr[adr_] = 0; // clear for next round + if constexpr (DTRMNSTC) { + #pragma unroll + for (int ks = 0; ks < k_rnd; ks++) { + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + int g_nindx = + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; + int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4; + __builtin_amdgcn_global_load_lds( + (float4*)(&glbl[g_adr + M * N * ks]), + &(((float4*)s)[(threadIdx.y * THRDS) + ks * THRDS * 4 + + nt * THRDS * 4 * k_rnd]), + 16, 0, 0); } } - for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { - for (uint32_t j = 0; j < 4; j++) { - int g_nindx = - j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; - int g_adr = g_mindx + M * g_nindx * 4; - vals[nt][j] = glbl[g_adr]; + if (BIAS) + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx]; + } + } + asm volatile("s_waitcnt 0"); + for (int ks = 0; ks < k_rnd; ks++) { + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + float4 eval = ((float4*)s)[(threadIdx.x + threadIdx.y * THRDS) + + ks * THRDS * 4 + nt * THRDS * 4 * k_rnd]; + vals[nt].f4 += eval; + } } + } else { + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + int g_nindx = + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; + int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4; + vals[nt].f4 = *(float4*)(&glbl[g_adr]); + *(float4*)(&glbl[g_adr]) = {}; // clear out for next round + } + if (BIAS) + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx]; + } + } } __builtin_amdgcn_sched_barrier(0); for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { @@ -1581,11 +1632,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (nindx < actlN) { int adr = mindx + M * nindx; if constexpr (std::is_same_v) { - vals[nt][j] += __bfloat162float(biases[nt][j]); - C[adr] = __float2bfloat16(vals[nt][j]); + vals[nt].s8[j] += __bfloat162float(biases[nt][j]); + C[adr] = __float2bfloat16(vals[nt].s8[j]); } else { - vals[nt][j] += __half2float(biases[nt][j]); - C[adr] = __float2half(vals[nt][j]); + vals[nt].s8[j] += __half2float(biases[nt][j]); + C[adr] = __float2half(vals[nt].s8[j]); } } } @@ -1604,21 +1655,25 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitKrc_(const int actlN, const int K, const int M, - const int Bx, const int By, const scalar_t* B, - const scalar_t* __restrict__ A, + int UNRL, int N, int GrpsShrB, int CHUNKK, int DTRMNSTC> +__global__ void wvSplitKrc_(const int actlN, const int K, const int Kap, + const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, const scalar_t* __restrict__ BIAS, float* glbl, - // int* cntr, - scalar_t* C, const int CuCount){UNREACHABLE_CODE} + int* cntr, scalar_t* C, + const int CuCount){UNREACHABLE_CODE} #endif // defined(__HIP__GFX9__) TODO: Add NAVI support torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, const std::optional& in_bias, const int64_t CuCount) { - auto M_in = in_a.size(0); - auto N_in = in_b.size(0); - auto K_in = in_a.size(1); + int _DTRMNSTC = 1; // vllm::vllm_is_batch_invariant(); + + auto M_in = in_b.size(0); + auto N_in = in_a.size(0); + auto K_in = in_b.size(1); + auto Kap_in = in_a.stride(0); + auto Bx_in = (in_bias.has_value() && in_bias->numel() > 0) ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) @@ -1635,13 +1690,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, auto out_c = torch::empty( {N_in, M_in}, - torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device())); + torch::TensorOptions().dtype(in_a.dtype()).device(in_a.device())); auto N_p2 = 1U << (32 - __builtin_clz(N_in - 1)); - auto axl_glbl = torch::empty( - {N_p2 + N_p2 / 4, M_in + M_in / 4}, - torch::TensorOptions().dtype(torch::kFloat32).device(in_b.device())); - axl_glbl.zero_(); // disable for FAST_UNSAFE_RDC_INIT dim3 grid(CuCount); @@ -1649,55 +1700,70 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // const int max_lds_len = get_lds_size() / 2; + // With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile), + // and each working on a 512-shard of K, how many CUs would we need? + int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512); + + // How many of 4 waves in a group can work on same 16 Ms at same time? First + // try to maximize this. This reduces the Ms each group works on, i.e. + // increasing the number of CUs needed. + int GrpsShrB = min(N_p2 / 16, 4); + + // Given the above, how many CUs would we need? + int CuNeeded = rndup_cus * GrpsShrB; + + if (CuNeeded > CuCount) throw std::runtime_error("Invalid wvSplitKrc size"); + + // Can we increase SplitK by shrinking the K-shared to 256? + int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1; + + static torch::Tensor axl_glbl = + torch::zeros( + 128 * 1024 * (_DTRMNSTC ? 12 : 1), + torch::TensorOptions().dtype(torch::kFloat32).device(in_a.device())) + .detach(); + static torch::Tensor axl_cntr = + torch::zeros( + 128 * 1024 * (_DTRMNSTC ? 12 : 1) / 4, + torch::TensorOptions().dtype(torch::kInt).device(in_a.device())) + .detach(); + auto glbl = axl_glbl.data_ptr(); + auto cntr = axl_cntr.data_ptr(); + #define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK) \ { \ dim3 block(64, 4); \ - wvSplitKrc_ \ - <<>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \ - biasf4, glbl, c, CuCount); \ + if (_DTRMNSTC) \ + wvSplitKrc_ \ + <<>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \ + af4, bf4, biasf4, glbl, cntr, c, \ + CuCount); \ + else \ + wvSplitKrc_ \ + <<>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \ + af4, bf4, biasf4, glbl, cntr, c, \ + CuCount); \ } - AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitKrc", [&] { + AT_DISPATCH_REDUCED_FLOATING_TYPES(in_a.scalar_type(), "wvSplitKrc", [&] { using fptype = typename scalar::type; - fptype* af4 = reinterpret_cast(in_a.data_ptr()); + const fptype* af4 = reinterpret_cast(in_a.data_ptr()); const fptype* bf4 = reinterpret_cast(in_b.data_ptr()); const fptype* biasf4 = (in_bias.has_value() && in_bias->numel() > 0) ? reinterpret_cast(in_bias->data_ptr()) : nullptr; fptype* c = reinterpret_cast(out_c.data_ptr()); - auto glbl = axl_glbl.data_ptr(); - - // With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile), - // and each working on a 512-shard of K, how many CUs would we need? - int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512); - - // How many of 4 waves in a group can work on same 16 Ms at same time? First - // try to maximize this. This reduces the Ms each group works on, i.e. - // increasing the number of CUs needed. - int GrpsShrB = min(N_p2 / 16, 4); - - // Given the above, how many CUs would we need? - int CuNeeded = rndup_cus * GrpsShrB; - - if (CuNeeded > CuCount) std::runtime_error("Invalid wvSplitKrc size"); - - // Can we increase SplitK by shrinking the K-shared to 256? - int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1; switch (N_p2) { case 16: WVSPLITKrc(16, 1, 1) break; case 32: - if (chunkk == 2) - WVSPLITKrc(32, 2, 2) else if (chunkk == 1) WVSPLITKrc(32, 2, 1) break; + if (chunkk == 2) WVSPLITKrc(32, 2, 2) else WVSPLITKrc(32, 2, 1) break; case 64: - if (chunkk == 2) - WVSPLITKrc(64, 4, 2) else if (chunkk == 1) WVSPLITKrc(64, 4, 1) break; + if (chunkk == 2) WVSPLITKrc(64, 4, 2) else WVSPLITKrc(64, 4, 1) break; case 128: - if (chunkk == 2) - WVSPLITKrc(128, 4, 2) else if (chunkk == 1) - WVSPLITKrc(128, 4, 1) break; + if (chunkk == 2) WVSPLITKrc(128, 4, 2) else WVSPLITKrc(128, 4, 1) break; default: throw std::runtime_error( "Unsupported N value: " + std::to_string(M_in) + "," + diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 1f55a597d..91b774c47 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -70,7 +70,6 @@ N_FACTORS_WVSPLITKRC = [ 117, 128, ] - K_FACTORS_WVSPLITKRC = [2880, 2880 + 8, 3072, 3072 + 8] M_FACTORS_WVSPLITKRC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16] @@ -123,10 +122,11 @@ def pad_fp8(weight): @pytest.mark.parametrize("m", M_FACTORS_WVSPLITKRC) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("padded_a", [False, True]) @pytest.mark.parametrize("bias_mode", BIAS_MODES) @pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") @pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950") -def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): +def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, padded_a, bias_mode): torch.manual_seed(seed) cu_count = num_compute_units() @@ -141,7 +141,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): # Given the above, how many CUs would we need? CuNeeded = rndup_cus * GrpsShrB # candidate for atomic reduce count splitk? - fits_wvsplitkrc = CuNeeded <= cu_count + fits_wvsplitkrc = (N_p2 * m * ((k + 512 - 1) // 512)) <= 128 * 1024 * 12 + fits_wvsplitkrc &= CuNeeded <= cu_count if not fits_wvsplitkrc: pytest.skip("Too large for wvSplitKrc") @@ -151,6 +152,8 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): ) # normalize to avoid large output-bias deltas A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier + if padded_a: + A = pad_fp8(A) BIAS = None if bias_mode == 1: @@ -159,7 +162,7 @@ def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1 ref_out = torch.nn.functional.linear(A, B, BIAS) - out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS) + out = ops.wvSplitKrc(A, B, cu_count, BIAS) if xnorm: torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index d1e35f583..e46e4fd39 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -129,10 +129,6 @@ def rocm_unquantized_gemm_impl( k = weight.shape[1] cu_count = num_compute_units() - if use_aiter_triton_gemm(n, m, k, x.dtype): - from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 - - return gemm_a16w16(x, weight, bias) # Next ^2 of n N_p2 = 1 << (n - 1).bit_length() @@ -145,7 +141,10 @@ def rocm_unquantized_gemm_impl( # Given the above, how many CUs would we need? CuNeeded = rndup_cus * GrpsShrB # candidate for atomic reduce count splitk? - fits_wvsplitkrc = CuNeeded <= cu_count + fits_wvsplitkrc = ( + N_p2 * m * ((k + 512 - 1) // 512) + ) <= 128 * 1024 * 12 # deterministic + fits_wvsplitkrc &= CuNeeded <= cu_count use_skinny_reduce_counting = ( envs.VLLM_ROCM_USE_SKINNY_GEMM @@ -157,13 +156,16 @@ def rocm_unquantized_gemm_impl( and k > 512 and m % 16 == 0 and fits_wvsplitkrc - and x.is_contiguous() + and weight.is_contiguous() ) ) if use_skinny_reduce_counting: - x_view = x.reshape(-1, x.size(-1)) - out = ops.wvSplitKrc(weight, x_view, cu_count, bias) - return out.reshape(*x.shape[:-1], weight.shape[0]) + return ops.wvSplitKrc(x, weight, cu_count, bias) + + if use_aiter_triton_gemm(n, m, k, x.dtype): + from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 + + return gemm_a16w16(x, weight, bias) use_skinny = ( envs.VLLM_ROCM_USE_SKINNY_GEMM