diff --git a/csrc/apis/attention.hpp b/csrc/apis/attention.hpp new file mode 100644 index 0000000..286387b --- /dev/null +++ b/csrc/apis/attention.hpp @@ -0,0 +1,77 @@ +#pragma once + +#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp" +#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" +#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" +#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp" + +#include "layout.hpp" + +namespace deep_gemm::attention { + +static void fp8_gemm_nt_skip_head_mid(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::tuple &head_splits, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [N, K].T` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + if (fp8_requires_k_major()) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + } + + // D must be N-major + check_major_type_cd(d); + + // Type and shape checks + const auto& [m , k ] = get_shape<2>(a.first); + const auto& [n , k_] = get_shape<2>(b.first); + const auto& [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Check head splits and N + const auto& [left, mid, right] = head_splits; + DG_HOST_ASSERT(n % (left + right) == 0 and n_ == n + n / (left + right) * mid); + + // Do nothing if the problem is empty + if (m == 0) + return; + + // Transform SFA and SFB into compute-required layout + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128)); + const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast); + const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast); + + // Dispatch into different implements + const auto& arch_major = device_runtime->get_arch_major(); + const auto& epilogue_type = fmt::format("EpilogueHeadSplits<{}, {}, {}>", left, mid, right); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) { + sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { + sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void register_apis(pybind11::module_& m) { + m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("head_splits"), + py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); +} + +} // namespace deep_gemm::attention diff --git a/csrc/apis/einsum.hpp b/csrc/apis/einsum.hpp new file mode 100644 index 0000000..5b05dbf --- /dev/null +++ b/csrc/apis/einsum.hpp @@ -0,0 +1,115 @@ +#pragma once + +#include +#include + +#include "../utils/exception.hpp" +#include "../utils/format.hpp" +#include "../utils/layout.hpp" + +#include "../jit_kernels/impls/sm90_bmk_bnk_mn.hpp" +#include "../jit_kernels/impls/sm100_bmk_bnk_mn.hpp" +#include "../jit_kernels/impls/smxx_cublaslt.hpp" + +namespace deep_gemm::einsum { + +static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, + const std::optional& c) { + // Currently FP32 only support the accumulated expression + if (d.scalar_type() == torch::kFloat) { + DG_HOST_ASSERT(c->data_ptr() == d.data_ptr() and c->sizes() == d.sizes() and c->strides() == d.strides()); + } else { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(not c.has_value()); + + const auto& workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32)); + DG_CUDA_RUNTIME_CHECK(cudaMemsetAsync(workspace.data_ptr(), 0, workspace.nbytes(), + c10::cuda::getCurrentCUDAStream())); + bmk_bnk_mn(a, b, workspace, workspace); + + // This line has an implicit FP32-to-BF16 casting + d.copy_(workspace); + return; + } + + DG_HOST_ASSERT(a.is_contiguous()); + DG_HOST_ASSERT(b.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + + const auto& [s , m, k ] = get_shape<3>(a); + const auto& [s_, n, k_] = get_shape<3>(b); + DG_HOST_ASSERT(s == s_ and k == k_); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bmn_bnk_mn_gemm(a, b, d, s, m, n, k); + } else if (arch_major == 10) { + sm100_bmn_bnk_mn_gemm(a, b, d, s, m, n, k); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D) { + const auto& [b , h , r ] = get_shape<3>(A); + const auto& [h_, d , r_] = get_shape<3>(B); + const auto& [b_, h__, d_] = get_shape<3>(D); + DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__); + + DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1); + DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1); + DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1); + + cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d); +} + +static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D) { + const auto& [b , h , d ] = get_shape<3>(A); + const auto& [h_, d_ , r ] = get_shape<3>(B); + const auto& [b_, h__, r_] = get_shape<3>(D); + DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__); + + DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1); + DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1); + DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1); + + cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d); +} + +static void einsum(const std::string& expr, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c) { + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + if (c.has_value()) { + DG_HOST_ASSERT(c->scalar_type() == torch::kFloat); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + } + + // Some hardcoded Einstein sum kernels + // TODO: support any expression + // TODO: canonicalize expression + if (expr == "bmk,bnk->mn") { + bmk_bnk_mn(a, b, d, c); + } else if (expr == "bhr,hdr->bhd") { + DG_HOST_ASSERT(not c.has_value()); + bhr_hdr_bhd(a, b, d); + } else if (expr == "bhd,hdr->bhr") { + DG_HOST_ASSERT(not c.has_value()); + bhd_hdr_bhr(a, b, d); + } else { + DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr)); + } +} + +static void register_apis(pybind11::module_& m) { + m.def("einsum", &einsum, + py::arg("expr"), py::arg("a"), py::arg("b"), + py::arg("d"), py::arg("c") = std::nullopt); +} + +} // namespace deep_gemm::einsum diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index a6bd344..8d06292 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -1,5 +1,6 @@ #pragma once +#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp" #include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" #include "../jit_kernels/impls/sm90_bf16_gemm.hpp" #include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" @@ -52,13 +53,18 @@ static void fp8_gemm_nt(const std::pair& a, // Transform SFA and SFB into compute-required layout if (not recipe.has_value()) recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128)); const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast); const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast); // Dispatch into different implements const auto& arch_major = device_runtime->get_arch_major(); if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { - sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + if (std::get<1>(recipe.value()) == 1) { + sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else { + sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { @@ -261,6 +267,60 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::optional& c, + const std::tuple& recipe, + const std::string& compiled_dims) { + // Must be 1D1D kernel + DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); + + // Shape checks + const auto& [num_groups, m, n] = get_shape<3>(d); + const auto& sum_mk = a.first.numel(); + const auto& sum_nk = b.first.numel(); + int sum_k = 0; + for (const auto& k: ks) + sum_k += k; + DG_HOST_ASSERT(sum_mk == m * sum_k); + DG_HOST_ASSERT(sum_nk == n * sum_k); + + // Contiguity checks + DG_HOST_ASSERT(a.first.is_contiguous()); + DG_HOST_ASSERT(b.first.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + if (c.has_value()) { + DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); + DG_HOST_ASSERT(c.value().is_contiguous()); + } + + // Do nothing if empty + if (std::accumulate(ks.begin(), ks.end(), 0) == 0) + return; + + // Transform SF with padding + const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe); + const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe); + + // Allocate tensormap buffer + // `4` means the double buffering for both A and B operands (2 * 2) + const auto& num_sms = device_runtime->get_num_sms(); + const auto& tensor_map_buffer = torch::empty({num_sms * 4 * static_cast(sizeof(CUtensorMap))}, + a.first.options().dtype(torch::kByte)); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer, + cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + static void bf16_gemm_nt(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, @@ -403,6 +463,43 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T } } +static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const std::optional& c) { + // Shape must be `[M, K] @ [N, K].T` + const auto& major_a = get_major_type_ab(a); + const auto& major_b = get_major_type_ab(b); + + // Type and shape checks + const auto& [m , k ] = get_shape<2>(a); + const auto& [n , k_] = get_shape<2>(b); + const auto& [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + + if (c.has_value()) + DG_HOST_ASSERT(c.value().scalar_type() == d.scalar_type()); + + // Do nothing if the problem is empty + if (m == 0 or n == 0) + return; + + cublaslt_gemm(a, b, c, d, m, n, k, major_a, major_b); +} + +static void cublaslt_gemm_nn(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const std::optional& c) { + cublaslt_gemm_nt(a, b.transpose(0, 1), d, c); +} + +static void cublaslt_gemm_tn(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const std::optional& c) { + cublaslt_gemm_nt(a.transpose(0, 1), b.transpose(0, 1), d, c); +} + +static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const std::optional& c) { + cublaslt_gemm_nt(a.transpose(0, 1), b, d, c); +} + static void register_apis(pybind11::module_& m) { // FP8 GEMMs m.def("fp8_gemm_nt", &fp8_gemm_nt, @@ -442,6 +539,11 @@ static void register_apis(pybind11::module_& m) { py::arg("ks_tensor"), py::arg("c") = std::nullopt, py::arg("recipe") = std::make_tuple(1, 1, 128), py::arg("compiled_dims") = "mn"); + m.def("k_grouped_fp8_gemm_nt_contiguous", &k_grouped_fp8_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), + py::arg("ks_tensor"), py::arg("c") = std::nullopt, + py::arg("recipe") = std::make_tuple(1, 1, 128), + py::arg("compiled_dims") = "mn"); // BF16 GEMMs m.def("bf16_gemm_nt", &bf16_gemm_nt, @@ -466,6 +568,16 @@ static void register_apis(pybind11::module_& m) { m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), py::arg("expected_m"), py::arg("compiled_dims") = "nk"); + + // cuBLASLt GEMMs + m.def("cublaslt_gemm_nt", &cublaslt_gemm_nt, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt); + m.def("cublaslt_gemm_nn", &cublaslt_gemm_nn, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt); + m.def("cublaslt_gemm_tn", &cublaslt_gemm_tn, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt); + m.def("cublaslt_gemm_tt", &cublaslt_gemm_tt, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt); } } // namespace deep_gemm::gemm diff --git a/csrc/apis/layout.hpp b/csrc/apis/layout.hpp index 27c4120..a9cc0b1 100644 --- a/csrc/apis/layout.hpp +++ b/csrc/apis/layout.hpp @@ -56,14 +56,14 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te // FP32 on SM90 if (sf.scalar_type() == torch::kFloat and arch_major == 9) - DG_HOST_UNREACHABLE("Unimplemented"); + return get_mn_major_tma_aligned_tensor(sf); // FP32 on SM100 if (sf.scalar_type() == torch::kFloat and arch_major == 10) return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks); // INT on SM100 - if (sf.scalar_type() == torch::kFloat and arch_major == 10) + if (sf.scalar_type() == torch::kInt and arch_major == 10) DG_HOST_UNREACHABLE("Unimplemented"); DG_HOST_UNREACHABLE("Unknown cases"); diff --git a/csrc/indexing/main.cu b/csrc/indexing/main.cu index a05b59c..42c03a8 100644 --- a/csrc/indexing/main.cu +++ b/csrc/indexing/main.cu @@ -1,9 +1,16 @@ +// GEMM kernels #include -#include #include #include +#include #include #include + +// Einsum kernels +#include +#include + +// Layout kernels #include using namespace deep_gemm; diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp index 79139d6..6ffd26f 100644 --- a/csrc/jit/device_runtime.hpp +++ b/csrc/jit/device_runtime.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include "../utils/exception.hpp" @@ -11,8 +12,28 @@ class DeviceRuntime { int num_sms = 0, tc_util = 0; std::shared_ptr cached_prop; + // cuBLASLt utils + static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024; + cublasLtHandle_t cublaslt_handle{}; + std::shared_ptr cublaslt_workspace; + public: - explicit DeviceRuntime() = default; + explicit DeviceRuntime() { + cublaslt_workspace = std::make_shared(torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA))); + DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle)); + } + + ~DeviceRuntime() noexcept(false) { + DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle)); + } + + cublasLtHandle_t get_cublaslt_handle() const { + return cublaslt_handle; + } + + torch::Tensor get_cublaslt_workspace() const { + return *cublaslt_workspace; + } std::shared_ptr get_prop() { if (cached_prop == nullptr) diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index 681e654..455223b 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include "../../utils/math.hpp" #include "../../utils/layout.hpp" @@ -80,18 +82,19 @@ static bool is_multicast_legal(const int& shape_dim, const int& block_dim, return divisible and num_sms % num_multicast == 0; } -static int get_swizzle_mode(const int& block_size, const int& elem_size) { +template +static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) { // `> 0` means interleaving // 16B actually means non-swizzling (but interleaving) for (const int& mode: {128, 64, 32, 16}) { - if ((block_size * elem_size) % mode == 0) + if ((block_size * static_cast(elem_size)) % mode == 0) return mode; } DG_HOST_UNREACHABLE("Unreachable"); } template -static SharedMemoryConfig get_smem_config(const KernelType& kernel_type, +static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const KernelType& kernel_type, const int& m, const int& n, const int& k, const int& block_m, const int& block_n, const int& block_k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, @@ -104,7 +107,7 @@ static SharedMemoryConfig get_smem_config(const KernelType& kernel_type, const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n); const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size); const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size); - const int& swizzle_cd_mode = get_swizzle_mode(block_n, cd_elem_size); + const int& swizzle_cd_mode = ArchSpec::enable_cd_swizzle(cd_dtype) ? get_swizzle_mode(block_n, cd_elem_size) : 0; // Different archs have different epilogue pipelines const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype); @@ -121,9 +124,11 @@ static SharedMemoryConfig get_smem_config(const KernelType& kernel_type, // M-barriers and tensor memory pointers const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages); const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size(); + const int& smem_tensor_map = ArchSpec::get_tensormap_smem_size(gemm_type); // Sum them up int smem_size = 0; + smem_size += smem_tensor_map; smem_size += smem_cd; smem_size += num_stages * smem_a_per_stage; smem_size += num_stages * smem_b_per_stage; @@ -151,15 +156,12 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); // Select M/N block sizes - // TODO: support `% 16 == 8` block size on SM90 auto block_ms = std::vector{64, 128, 256}; if (gemm_type == GemmType::MGroupedContiguous) block_ms = std::vector{get_mk_alignment_for_contiguous_layout()}; if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance block_ms = std::vector{64, 128}; - std::vector block_ns; - for (int i = 16; i <= 256; i += 16) - block_ns.push_back(i); + const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype); // K block size is selected in a fixed manner const auto& block_k = 128 / static_cast(c10::elementSize(ab_dtype)); @@ -214,9 +216,9 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0); // Decide the number of TMA multicasts and whether broadcast on A - MulticastConfig best_multicast_config = {1, true}; + MulticastConfig best_multicast_config = {1, false}; const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality( - gemm_type, m, n, best_block_m, best_block_n, num_sms); + gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms); const bool is_legal[2] = {is_legal_on_b, is_legal_on_a}; bool order[2] = {false, true}; if (best_block_m > best_block_n) @@ -232,11 +234,11 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k constexpr int smem_capacity = ArchSpec::smem_capacity; int best_num_stages = 0; SharedMemoryConfig best_smem_config; - for (int num_stages = std::min(12, ceil_div(k, block_k)); num_stages > 0; -- num_stages) { + for (int num_stages = 12; num_stages > 0; -- num_stages) { if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k)) continue; - best_smem_config = get_smem_config(kernel_type, + best_smem_config = get_smem_config(gemm_type, kernel_type, m, n, k, best_block_m, best_block_n, block_k, major_a, major_b, diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp index 0679cad..e62a13c 100644 --- a/csrc/jit_kernels/heuristics/sm100.hpp +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -12,6 +12,15 @@ namespace deep_gemm { struct SM100ArchSpec { static constexpr int smem_capacity = 232448; + static std::vector get_block_n_candidates(const at::ScalarType& cd_dtype) { + // 16 is for better SM usage + // Stride 32 is due to low-performance swizzle-16/32B + std::vector candidates = {16}; + for (int i = 32; i <= 256; i += 32) + candidates.push_back(i); + return candidates; + } + static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) { return block_m / (config.is_multicast_on_a ? config.num_multicast : 1); } @@ -29,6 +38,10 @@ struct SM100ArchSpec { return block_n; } + static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) { + return true; + } + static std::pair get_sf_uttcp_aligned_block_sizes( const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) { constexpr int num_utccp_aligned_elems = 128; @@ -86,7 +99,7 @@ struct SM100ArchSpec { return false; } - static std::pair get_multicast_legality(const GemmType& gemm_type, + static std::pair get_multicast_legality(const GemmType& gemm_type, const int& num_groups, const int& m, const int& n, const int& block_m, const int& block_n, const int& num_sms) { // TODO: support other layouts @@ -138,12 +151,17 @@ struct SM100ArchSpec { // TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers // NOTES: 1D2D kernel will not use the with-SF full barriers // NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages - return num_stages * 8 * 3 + 2 * 8 * 2; + // NOTES: the last barrier is for tensor core utilization control + return num_stages * 8 * 3 + 2 * 8 * 2 + 8; } static int get_tmem_ptr_smem_size() { return 4; } + + static int get_tensormap_smem_size(const GemmType& gemm_type) { + return 0; + } }; } // namespace deep_gemm diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp index 58faecf..133e2da 100644 --- a/csrc/jit_kernels/heuristics/sm90.hpp +++ b/csrc/jit_kernels/heuristics/sm90.hpp @@ -11,6 +11,15 @@ namespace deep_gemm { struct SM90ArchSpec { static constexpr int smem_capacity = 232448; + static std::vector get_block_n_candidates(const at::ScalarType& cd_dtype) { + // Avoid bank conflicts for FP32 output + const auto& start = cd_dtype == torch::kFloat ? 8 : 16; + std::vector candidates; + for (int i = start; i <= 256; i += 16) + candidates.push_back(i); + return candidates; + } + static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) { return block_m; } @@ -19,26 +28,35 @@ struct SM90ArchSpec { return block_n; } - static int get_cd_store_block_m(const int& block_m) { - return block_m; + static int get_cd_store_block_m(const int& block_m, const bool& single_warpgroup_sync = false) { + constexpr int wgmma_m = 64; + return single_warpgroup_sync ? wgmma_m : block_m; } static int get_cd_store_block_n(const int& block_n) { return block_n; } + static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) { + return cd_dtype != torch::kFloat; + } + static bool is_block_size_legal(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, const int& block_m, const int& block_n, const int& block_k) { - // FP32 output does not support `block_m == 256` + // SM90 FP32 output does not support `block_m == 256` if (cd_dtype == at::kFloat and block_m == 256) return false; - // TODO: more general block N selection - // Must be some fixed block N selections - if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 and block_n != 152)) - return false; + // Avoid large C/D shared memory for FP32 output + // Ensure `num_stages >= 4` (for 1D1D Kernel), `num_stages >= 3` (for No SF kernel) + if (block_n > 128 and cd_dtype == torch::kFloat) { + if (kernel_type == KernelType::Kernel1D1D and block_n > 152) + return false; + if (kernel_type == KernelType::KernelNoSF and block_n > 200) + return false; + } // Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k` // Or too many register spills @@ -66,9 +84,13 @@ struct SM90ArchSpec { return true; } - static std::pair get_multicast_legality(const GemmType& gemm_type, + static std::pair get_multicast_legality(const GemmType& gemm_type, const int& num_groups, const int& m, const int& n, const int& block_m, const int& block_n, const int& num_sms) { + // Disable multicast when the number of k-groups is large (a heuristic) + if (gemm_type == GemmType::KGroupedContiguous and num_groups > 4) + return {false, false}; + return { is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked), // For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even @@ -96,9 +118,10 @@ struct SM90ArchSpec { int smem_sfa_per_stage = block_m * static_cast(sizeof(float)); int smem_sfb_per_stage = 0; - // TODO: figure out here - if (kernel_type == KernelType::Kernel1D1D) - smem_sfb_per_stage = align(block_n * 4, block_k); + if (kernel_type == KernelType::Kernel1D1D) { + // NOTES: `128` is for 2D TMA alignment requirement + smem_sfb_per_stage = align(block_n * 4, 128); + } return {smem_sfa_per_stage, smem_sfb_per_stage}; } @@ -109,13 +132,16 @@ struct SM90ArchSpec { } static int get_barrier_smem_size(const int& num_stages) { - // For 1D1D kernels, there is an extra barrier for accumulation - return (num_stages + 1) * 8 * 2; + return num_stages * 8 * 2; } static int get_tmem_ptr_smem_size() { return 0; } + + static int get_tensormap_smem_size(const GemmType& gemm_type) { + return gemm_type == GemmType::KGroupedContiguous ? 4 * static_cast(sizeof(CUtensorMap)) : 0; + } }; } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/epilogue.hpp b/csrc/jit_kernels/impls/epilogue.hpp new file mode 100644 index 0000000..bd21de1 --- /dev/null +++ b/csrc/jit_kernels/impls/epilogue.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +static std::string get_default_epilogue_type(const std::optional& epilogue_type) { + return epilogue_type.value_or("EpilogueIdentity"); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index ed9c530..92685ef 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -4,6 +4,8 @@ #include #include "../../utils/math.hpp" +#include "../heuristics/sm90.hpp" +#include "../../utils/system.hpp" #include "../../utils/exception.hpp" namespace deep_gemm { @@ -51,7 +53,11 @@ static std::string to_string(const at::ScalarType& dtype) { } } -static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype) { +static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype, + const bool& allow_tf32) { + if (allow_tf32 and dtype == torch::kFloat) + return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; + switch (dtype) { case torch::kInt: return CU_TENSOR_MAP_DATA_TYPE_INT32; case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; @@ -61,9 +67,14 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& } } -static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode) { +static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) { + if (base != 0) { + DG_HOST_ASSERT(base == 32 and mode == 128); + return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B; + } + switch (mode) { - case 0: return CU_TENSOR_MAP_SWIZZLE_NONE; + case 0: case 16: return CU_TENSOR_MAP_SWIZZLE_NONE; case 32: return CU_TENSOR_MAP_SWIZZLE_32B; case 64: return CU_TENSOR_MAP_SWIZZLE_64B; @@ -76,7 +87,8 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, int gmem_inner_dim, int gmem_outer_dim, int smem_inner_dim, int smem_outer_dim, const int& gmem_outer_stride, - const int& swizzle_mode) { + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { const auto& elem_size = static_cast(t.element_size()); if (swizzle_mode != 0) smem_inner_dim = swizzle_mode / elem_size; @@ -87,14 +99,42 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, const cuuint64_t gmem_strides[1] = {static_cast(gmem_outer_stride * elem_size), }; const cuuint32_t elem_strides[2] = {1, 1}; if (get_env("DG_JIT_DEBUG")) { - printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d, elem size: %d\n", + printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d\n", gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim, - gmem_outer_stride, swizzle_mode, elem_size); + gmem_outer_stride, swizzle_mode, swizzle_base, elem_size); } DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( - &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type()), + &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32), 2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides, - CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode), + CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base), + CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensor_map; +} + +static CUtensorMap make_tma_3d_desc(const torch::Tensor& t, + const int& gmem_dim_0, const int& gmem_dim_1, const int& gmem_dim_2, + const int& smem_dim_0, const int& smem_dim_1, const int& smem_dim_2, + const int& gmem_stride_0, const int& gmem_stride_1, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { + const auto& elem_size = static_cast(t.element_size()); + if (swizzle_mode != 0) + DG_HOST_ASSERT(smem_dim_0 == swizzle_mode / elem_size); + + CUtensorMap tensor_map; + const cuuint64_t gmem_dims[3] = {static_cast(gmem_dim_0), static_cast(gmem_dim_1), static_cast(gmem_dim_2),}; + const cuuint32_t smem_dims[3] = {static_cast(smem_dim_0), static_cast(smem_dim_1), static_cast(smem_dim_2)}; + const cuuint64_t gmem_strides[2] = {static_cast(gmem_stride_0 * elem_size), static_cast(gmem_stride_1 * elem_size)}; + const cuuint32_t elem_strides[3] = {1, 1, 1}; + if (get_env("DG_JIT_DEBUG")) { + printf("Making 3D TMA desc: global memory: %d %d %d, shared memory: %d %d %d, outer stride: %d %d, swizzle: %d, elem size: %d\n", + gmem_dim_0, gmem_dim_1, gmem_dim_2, smem_dim_0, smem_dim_1, smem_dim_2, + gmem_stride_0, gmem_stride_1, swizzle_mode, elem_size); + } + DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32), + 3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides, + CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base), CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); return tensor_map; } @@ -105,7 +145,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major, const int& block_m, const int& block_k, const int& outer_stride, const int& num_groups, - const int& swizzle_mode) { + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { if (num_groups > 1) DG_HOST_ASSERT(major == cute::UMMA::Major::K); const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups); @@ -114,7 +155,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major, gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim, outer_stride, - swizzle_mode); + swizzle_mode, swizzle_base, + allow_tf32); } static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major, @@ -123,7 +165,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major, const int& block_n, const int& block_k, const int& outer_stride, const int& num_groups, - const int& swizzle_mode) { + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n); const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n); @@ -132,7 +175,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major, gmem_inner_dim, gmem_outer_dim * num_groups, smem_inner_dim, smem_outer_dim, outer_stride, - swizzle_mode); + swizzle_mode, swizzle_base, + allow_tf32); } static CUtensorMap make_tma_cd_desc(const torch::Tensor& t, @@ -140,15 +184,16 @@ static CUtensorMap make_tma_cd_desc(const torch::Tensor& t, const int& block_m, const int& block_n, const int& outer_stride, const int& num_groups, - const int& swizzle_mode) { - + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { // Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode` // bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required return make_tma_2d_desc(t, shape_n, shape_m * num_groups, block_n, block_m, outer_stride, - swizzle_mode); + swizzle_mode, swizzle_base, + allow_tf32); } static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, @@ -156,7 +201,8 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, int shape_mn, int shape_k, const int& block_mn, const int& block_k, const int& num_groups, - const int& swizzle_mode) { + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { DG_HOST_ASSERT(major == cute::UMMA::Major::MN); // TODO: maybe swizzle SF as well @@ -167,7 +213,8 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups, block_mn, 1, shape_mn, - swizzle_mode); + swizzle_mode, swizzle_base, + allow_tf32); } } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp index 033a7b7..e913d80 100644 --- a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp @@ -42,7 +42,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, {}, {}, - {}, {}, + {}, {}, {}, {}, {}, {}, @@ -56,7 +56,7 @@ static void __instantiate_kernel() {{ args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, args.num_groups, args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, - args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.num_stages, args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, args.gemm_config.num_sms, @@ -80,8 +80,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims) { - // TODO: test other Ks - DG_HOST_ASSERT(k % 64 == 0); + const auto& aligned_k = align(k, 64); const auto& config = get_best_config( GemmType::Normal, KernelType::KernelNoSF, m, n, k, 1, major_a, major_b, @@ -122,7 +121,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a, // Launch const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, + .m = m, .n = n, .k = aligned_k, .num_groups = 1, .compiled_dims = compiled_dims, .gemm_config = config, diff --git a/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp b/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp new file mode 100644 index 0000000..5503d51 --- /dev/null +++ b/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp @@ -0,0 +1,137 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100BmkBnkMnRuntime final: public LaunchRuntime { +public: + struct Args { + int s, m, n, k; + int block_m, block_n, block_k; + int split_factor; + int swizzle_ab_mode, swizzle_cd_mode; + int num_stages; + int num_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_bmn_bnk_mn_gemm_impl< + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, + {}, {} + >); +}}; +)", + args.m, args.n, args.k, + args.block_m, args.block_n, args.block_k, + args.split_factor, + args.swizzle_ab_mode, args.swizzle_cd_mode, + args.num_stages, args.num_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.s, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d)); + } +}; + + +static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a, + const torch::Tensor &b, + const torch::Tensor &d, + const int &s, const int &m, const int &n, const int &k) { + constexpr int block_m = 128; + constexpr int block_n = 128; + constexpr int block_k = 64; + constexpr int num_threads = 128; + DG_HOST_ASSERT(k % block_k == 0); + DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0); + DG_HOST_ASSERT(static_cast(s) * static_cast(std::max(m, n)) <= std::numeric_limits::max()); + + const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast(a.element_size())); + const int swizzle_cd_mode = get_swizzle_mode(block_n, static_cast(d.element_size())); + + // Get best config + const int num_sms = device_runtime->get_num_sms(); + const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n); + const int num_sk_blocks = s * (k / block_k); + const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1)); + + // Select best number of stages + // NOTES: we select 4 as start, as it is tested to be faster than values > 4 + int num_stages = 4, smem_size = 0; + while (true) { + const int& smem_cd = block_m * swizzle_cd_mode * 2; + const int& smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t); + const int& smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t); + const int& smem_barrier = SM100ArchSpec::get_barrier_smem_size(num_stages); + const int& smem_tmem_ptr = SM100ArchSpec::get_tmem_ptr_smem_size(); + + smem_size = 0; + smem_size += smem_cd; + smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages; + smem_size += smem_barrier; + smem_size += smem_tmem_ptr; + if (smem_size <= SM100ArchSpec::smem_capacity) + break; + + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("S: %d, M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split-K factor: %d" + "stages: %d, shared memory: %d, swizzle AB: %d, swizzle CD: %d\n", + s, m, n, k, block_m, block_n, block_k, split_factor, + num_stages, smem_size, swizzle_ab_mode, swizzle_cd_mode); + } + + const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode); + const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode); + const auto& tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode); + + const SM100BmkBnkMnRuntime::Args& args = { + .s = s, .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .split_factor = split_factor, + .swizzle_ab_mode = swizzle_ab_mode, + .swizzle_cd_mode = swizzle_cd_mode, + .num_stages = num_stages, + .num_threads = num_threads, + .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_threads, smem_size), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d + }; + const auto& code = SM100BmkBnkMnRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code); + SM100BmkBnkMnRuntime::launch(runtime, args); +} + +} // namespace deep_gemm \ No newline at end of file diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index 67272d9..de4964f 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -9,6 +9,8 @@ #include "../../utils/format.hpp" #include "../../utils/math.hpp" #include "../heuristics/sm100.hpp" + +#include "epilogue.hpp" #include "runtime_utils.hpp" namespace deep_gemm { @@ -18,6 +20,7 @@ public: struct Args { int m, n, k, num_groups; const std::string& compiled_dims; + const std::optional& epilogue_type; GemmConfig gemm_config; LaunchArgs launch_args; @@ -44,11 +47,12 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, {}, {}, - {}, {}, + {}, {}, {}, {}, {}, {}, - {}, {}, {} + {}, {}, {}, + {} >); }}; )", @@ -57,11 +61,12 @@ static void __instantiate_kernel() {{ args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, args.num_groups, args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, - args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.num_stages, args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, args.gemm_config.num_sms, - to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype)); + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype), + get_default_epilogue_type(args.epilogue_type)); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { @@ -80,7 +85,8 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa const torch::Tensor& d, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { + const std::string& compiled_dims, + const std::optional& epilogue_type = std::nullopt) { const auto& aligned_k = align(k, 128); const auto& config = get_best_config( GemmType::Normal, KernelType::Kernel1D1D, @@ -99,7 +105,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa config.block_k, static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast(d.size(-1)), SM100ArchSpec::get_cd_store_block_m(config.block_m), SM100ArchSpec::get_cd_store_block_n(config.block_n), static_cast(d.stride(-2)), 1, @@ -129,6 +135,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa .m = m, .n = n, .k = aligned_k, .num_groups = 1, .compiled_dims = compiled_dims, + .epilogue_type = epilogue_type, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, @@ -186,6 +193,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con .m = m, .n = n, .k = aligned_k, .num_groups = num_groups, .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, @@ -244,6 +252,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t .m = m, .n = n, .k = aligned_k, .num_groups = num_groups, .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, @@ -324,6 +333,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& .m = m, .n = n, .k = sum_k, .num_groups = num_groups, .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp index 727d1b7..bc272ae 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp @@ -18,6 +18,7 @@ public: struct Args { int m, n, k, num_groups; const std::string& compiled_dims; + const std::optional& epilogue_type; GemmConfig gemm_config; LaunchArgs launch_args; @@ -46,7 +47,8 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, - {}, {} + {}, {}, + {} >); }}; )", @@ -59,7 +61,8 @@ static void __instantiate_kernel() {{ args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, args.gemm_config.num_sms, - to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype)); + to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype), + get_default_epilogue_type(args.epilogue_type)); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { @@ -78,7 +81,8 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa const torch::Tensor& d, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { + const std::string& compiled_dims, + const std::optional& epilogue_type = std::nullopt) { DG_HOST_ASSERT(not c.has_value()); const auto& aligned_k = align(k, 128); @@ -98,7 +102,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa config.block_k, static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast(d.size(-1)), SM100ArchSpec::get_cd_store_block_m(config.block_m), SM100ArchSpec::get_cd_store_block_n(config.block_n), static_cast(d.stride(-2)), 1, @@ -111,6 +115,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa .m = m, .n = n, .k = aligned_k, .num_groups = 1, .compiled_dims = compiled_dims, + .epilogue_type = epilogue_type, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, @@ -164,6 +169,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con .m = m, .n = n, .k = aligned_k, .num_groups = num_groups, .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, @@ -218,6 +224,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const t .m = m, .n = n, .k = aligned_k, .num_groups = num_groups, .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, diff --git a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp index ea29883..7b4c4f6 100644 --- a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -41,7 +41,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, {}, - {}, {} + {}, {}, {} >); }}; )", @@ -53,7 +53,8 @@ static void __instantiate_kernel() {{ args.gemm_config.num_stages, args.gemm_config.num_last_stages, args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, - args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type)); + args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), + to_string(args.gemm_config.cd_dtype)); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { @@ -73,10 +74,10 @@ static void sm90_bf16_gemm(const torch::Tensor& a, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims) { - DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(not c.has_value()); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - DG_HOST_ASSERT(k % 64 == 0); + const auto& aligned_k = align(k, 64); const auto& config = get_best_config( GemmType::Normal, KernelType::KernelNoSF, m, n, k, 1, major_a, major_b, @@ -102,7 +103,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a, // Launch const SM90BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = k, + .m = m, .n = n, .k = aligned_k, .num_groups = 1, .compiled_dims = compiled_dims, .gemm_config = config, diff --git a/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp b/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp new file mode 100644 index 0000000..ccaea7f --- /dev/null +++ b/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp @@ -0,0 +1,131 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90BmkBnkMnRuntime final: public LaunchRuntime { +public: + struct Args { + int s, m, n, k; + int block_m, block_n, block_k; + int split_factor; + int num_stages; + int num_tma_threads, num_math_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + float* d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_bmn_bnk_mn_gemm_impl< + {}, {}, {}, + {}, {}, {}, + {}, + {}, + {}, {} + >); +}}; +)", + args.m, args.n, args.k, + args.block_m, args.block_n, args.block_k, + args.split_factor, + args.num_stages, + args.num_tma_threads, args.num_math_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.s, args.tensor_map_a, args.tensor_map_b, args.d)); + } +}; + + +static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a, + const torch::Tensor &b, + const torch::Tensor &d, + const int &s, const int &m, const int &n, const int &k) { + constexpr int block_m = 128; + constexpr int block_n = 128; + constexpr int block_k = 64; + constexpr int num_tma_threads = 128; + constexpr int num_math_threads = 256; + DG_HOST_ASSERT(k % block_k == 0); + DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0); + DG_HOST_ASSERT(static_cast(s) * static_cast(std::max(m, n)) <= std::numeric_limits::max()); + + const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast(a.element_size())); + DG_HOST_ASSERT(swizzle_ab_mode == 128); + + // Get best config + const int num_sms = device_runtime->get_num_sms(); + const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n); + const int num_sk_blocks = s * (k / block_k); + const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1)); + + // Select best number of stages + int num_stages = 4, smem_size = 0; + while (true) { + const int& smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t); + const int& smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t); + const int& smem_barrier = SM90ArchSpec::get_barrier_smem_size(num_stages); + + smem_size = 0; + smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages; + smem_size += smem_barrier; + + if (smem_size <= SM90ArchSpec::smem_capacity) + break; + + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("S: %d, M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split-K factor: %d" + "stages: %d, shared memory: %d, swizzle AB: %d\n", + s, m, n, k, block_m, block_n, block_k, split_factor, + num_stages, smem_size, swizzle_ab_mode); + } + + const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode); + const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode); + + const SM90BmkBnkMnRuntime::Args& args = { + .s = s, .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .split_factor = split_factor, + .num_stages = num_stages, + .num_tma_threads = num_tma_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_tma_threads + num_math_threads, smem_size), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .d = d.data_ptr() + }; + const auto& code = SM90BmkBnkMnRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code); + SM90BmkBnkMnRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp new file mode 100644 index 0000000..2f54a35 --- /dev/null +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp @@ -0,0 +1,214 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *gmem_a_ptr; + void *gmem_b_ptr; + void *grouped_layout; + void *tensor_map_buffer; + CUtensorMap tensor_map_a_base; + CUtensorMap tensor_map_b_base; + CUtensorMap tensor_map_sfa; + CUtensorMap tensor_map_sfb; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_fp8_gemm_1d1d_impl< + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, {} + >); +}}; +)", + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.num_groups, + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.gemm_config.num_stages, + args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), + to_string(args.gemm_config.cd_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.gmem_a_ptr, args.gmem_b_ptr, + args.grouped_layout, + args.tensor_map_buffer, + args.m, args.n, args.k, + args.tensor_map_a_base, args.tensor_map_b_base, + args.tensor_map_sfa, args.tensor_map_sfb, + args.tensor_map_d)); + } +}; + +static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& config = get_best_config( + GemmType::Normal, KernelType::Kernel1D1D, + m, n, k, 1, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, k, 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, k, 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, config.block_k, 1, 0); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m, true), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + 0); + + // Launch + const SM90FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .gmem_a_ptr = nullptr, + .gmem_b_ptr = nullptr, + .grouped_layout = nullptr, + .tensor_map_buffer = nullptr, + .tensor_map_a_base = tensor_map_a, + .tensor_map_b_base = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_d = tensor_map_d, + }; + const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); + + SM90FP8Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const torch::Tensor& tensor_map_buffer, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + // Get config using max K for better performance + const auto& num_groups = static_cast(ks.size()); + const auto& max_k = *std::max_element(ks.begin(), ks.end()); + const auto& config = get_best_config( + GemmType::KGroupedContiguous, KernelType::Kernel1D1D, + m, n, max_k, num_groups, major_a, major_b, + torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + + int first_k = 0, sum_k = 0, sum_sf_k = 0; + for (int i = 0; i < num_groups; ++ i) { + if (first_k == 0 and ks[i] != 0) + first_k = ks[i]; + sum_k += ks[i], sum_sf_k += ceil_div(ks[i], 128); + DG_HOST_ASSERT(ks[i] % 128 == 0); + } + const auto& tensor_map_a_base = make_tma_a_desc(major_a, a, m, first_k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, first_k, 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b_base = make_tma_b_desc(major_b, b, n, first_k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, first_k, 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128, + config.block_m, config.block_k, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128, + config.block_n, config.block_k, 1, 0); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m, true), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = sum_k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .gmem_a_ptr = a.data_ptr(), + .gmem_b_ptr = b.data_ptr(), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_buffer = tensor_map_buffer.data_ptr(), + .tensor_map_a_base = tensor_map_a_base, + .tensor_map_b_base = tensor_map_b_base, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_d = tensor_map_d, + }; + const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); + + SM90FP8Gemm1D1DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index 3afc2d3..ac87860 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -8,6 +8,8 @@ #include "../../utils/exception.hpp" #include "../../utils/format.hpp" #include "../heuristics/sm90.hpp" + +#include "epilogue.hpp" #include "runtime_utils.hpp" namespace deep_gemm { @@ -17,6 +19,7 @@ public: struct Args { int m, n, k, num_groups; const std::string& compiled_dims; + const std::optional& epilogue_type; GemmConfig gemm_config; LaunchArgs launch_args; @@ -43,7 +46,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, {}, - {}, {} + {}, {}, {} >); }}; )", @@ -55,7 +58,8 @@ static void __instantiate_kernel() {{ args.gemm_config.num_stages, args.gemm_config.num_last_stages, args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, - args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type)); + args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type), + get_default_epilogue_type(args.epilogue_type)); } static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { @@ -74,7 +78,8 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& d, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { + const std::string& compiled_dims, + const std::optional& epilogue_type = std::nullopt) { DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); @@ -98,7 +103,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, config.block_k, static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, config.smem_config.swizzle_b_mode); - const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast(d.size(-1)), SM90ArchSpec::get_cd_store_block_m(config.block_m), SM90ArchSpec::get_cd_store_block_n(config.block_n), static_cast(d.stride(-2)), 1, @@ -111,6 +116,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, .m = m, .n = n, .k = aligned_k, .num_groups = 1, .compiled_dims = compiled_dims, + .epilogue_type = epilogue_type, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, @@ -170,6 +176,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons .m = m, .n = n, .k = aligned_k, .num_groups = num_groups, .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, @@ -230,6 +237,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to .m = m, .n = n, .k = aligned_k, .num_groups = num_groups, .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, diff --git a/csrc/jit_kernels/impls/smxx_cublaslt.hpp b/csrc/jit_kernels/impls/smxx_cublaslt.hpp new file mode 100644 index 0000000..0881607 --- /dev/null +++ b/csrc/jit_kernels/impls/smxx_cublaslt.hpp @@ -0,0 +1,151 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm { + +static auto get_cublaslt_layout(const cudaDataType& type, const int& rows, const int& cols, const int& ld, + const std::optional& batch_count = std::nullopt, + const std::optional& batch_offset = std::nullopt) { + cublasLtMatrixLayout_t layout; + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&layout, type, rows, cols, ld)); + if (batch_count.has_value()) { + DG_HOST_ASSERT(batch_offset.has_value()); + + const int64_t batch_offset_int64 = batch_offset.value(); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count.value(), sizeof(batch_count.value()))); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_offset_int64, sizeof(batch_offset_int64))); + } + return layout; +} + +static void call_cublaslt_api(const cublasOperation_t& trans_a, + const cublasOperation_t& trans_b, + const cublasLtMatrixLayout_t& layout_a, + const cublasLtMatrixLayout_t& layout_b, + const cublasLtMatrixLayout_t& layout_d, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const bool& accumulate) { + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + cudaDataType_t scale_type = CUDA_R_32F; + const int& math_sms = device_runtime->get_num_sms(); + bool fp8_fast_accumulate = false; + + // Operation description + cublasLtMatmulDesc_t desc; + DG_CUBLASLT_CHECK(cublasLtMatmulDescCreate(&desc, compute_type, scale_type)); + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a))); + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b))); + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms))); + if (a.scalar_type() == torch::kFloat8_e4m3fn) + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate))); + + // Get cuBLASLt handle, workspace, and stream + const auto& handle = device_runtime->get_cublaslt_handle(); + const auto& workspace = device_runtime->get_cublaslt_workspace(); + const auto& workspace_bytes = workspace.nbytes(); + const auto& stream = at::cuda::getCurrentCUDAStream(); + + // Algorithm selection + cublasLtMatmulPreference_t pref; + cublasLtMatmulHeuristicResult_t heuristic; + int num_heuristic_results = 0; + uint32_t reduction_scheme_mask = CUBLASLT_REDUCTION_SCHEME_NONE | CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE; + DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceCreate(&pref)); + DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_bytes, sizeof(workspace_bytes))); + DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, + &reduction_scheme_mask, sizeof(reduction_scheme_mask))); + DG_CUBLASLT_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, layout_a, layout_b, layout_d, layout_d, + pref, 1, &heuristic, &num_heuristic_results)); + DG_HOST_ASSERT(num_heuristic_results == 1 and "Unable to find any algorithm for the GEMM"); + + // Call: D = alpha * (A @ B) + beta * C + const float& alpha = 1.0, beta = accumulate ? 1.0 : 0.0; + DG_CUBLASLT_CHECK(cublasLtMatmul(handle, // Light handle + desc, // Operation description + &alpha, // Alpha + b.data_ptr(), layout_a, // A + a.data_ptr(), layout_b, // B + &beta, // Beta + d.data_ptr(), layout_d, // C + d.data_ptr(), layout_d, // D + &heuristic.algo, // Algorithm + workspace.data_ptr(), workspace_bytes, // Workspace + stream)); // Stream + + // Free memory + DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceDestroy(pref)); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_a)); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_b)); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_d)); + DG_CUBLASLT_CHECK(cublasLtMatmulDescDestroy(desc)); +} + +static void cublaslt_gemm(const torch::Tensor& lhs, const torch::Tensor& rhs, + const std::optional& acc, + const torch::Tensor& out, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major) { + const auto& trans_a = b_major == cute::UMMA::Major::K ? CUBLAS_OP_T : CUBLAS_OP_N; + const auto& trans_b = a_major == cute::UMMA::Major::K ? CUBLAS_OP_N : CUBLAS_OP_T; + + // Duplicate the accumulator if necessary + // TODO: remove this + if (acc.has_value()) { + if (acc->data_ptr() == out.data_ptr()) { + DG_HOST_ASSERT(acc->sizes() == out.sizes() and acc->strides() == out.strides()); + } else { + out.copy_(acc.value()); + } + } + + // Matrix layouts + const auto& cuda_type_a = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type()); + const auto& cuda_type_b = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type()); + const auto& cuda_type_d = at::cuda::ScalarTypeToCudaDataType(out.scalar_type()); + const auto& layout_a = b_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_a, k, n, rhs.stride(0)) + : get_cublaslt_layout(cuda_type_a, n, k, rhs.stride(1)); + const auto& layout_b = a_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_b, k, m, lhs.stride(0)) + : get_cublaslt_layout(cuda_type_b, m, k, lhs.stride(1)); + const auto& layout_d = get_cublaslt_layout(cuda_type_d, n, m, out.stride(0)); + + call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, acc.has_value()); +} + + +static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out, + const int& b, const int& h, const int& r, const int& d) { + const auto& m = d, n = b, k = r; + const auto& trans_a = CUBLAS_OP_T; + const auto& trans_b = CUBLAS_OP_N; + + // Matrix layouts + const auto& layout_a = get_cublaslt_layout(CUDA_R_16BF, k, m, rhs.stride(1), h, rhs.stride(0)); + const auto& layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1)); + const auto& layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1)); + + call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false); +} + + +static void cublaslt_bhd_hdr_bhr(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out, + const int& b, const int& h, const int& r, const int& d) { + const auto& m = r, n = b, k = d; + const auto& trans_a = CUBLAS_OP_N; + const auto& trans_b = CUBLAS_OP_N; + + // Matrix layouts + const auto& layout_a = get_cublaslt_layout(CUDA_R_16BF, m, k, rhs.stride(1), h, rhs.stride(0)); + const auto& layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1)); + const auto& layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1)); + + call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false); +} + +} // namespace deep_gemm diff --git a/csrc/python_api.cpp b/csrc/python_api.cpp index d4b210a..32983a2 100644 --- a/csrc/python_api.cpp +++ b/csrc/python_api.cpp @@ -1,6 +1,8 @@ #include #include +#include "apis/attention.hpp" +#include "apis/einsum.hpp" #include "apis/gemm.hpp" #include "apis/layout.hpp" #include "apis/runtime.hpp" @@ -13,6 +15,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "DeepGEMM C++ library"; + deep_gemm::attention::register_apis(m); + deep_gemm::einsum::register_apis(m); deep_gemm::gemm::register_apis(m); deep_gemm::layout::register_apis(m); deep_gemm::runtime::register_apis(m); diff --git a/csrc/utils/exception.hpp b/csrc/utils/exception.hpp index 57cc513..b61bc09 100644 --- a/csrc/utils/exception.hpp +++ b/csrc/utils/exception.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -72,4 +73,16 @@ do { \ } while (0) #endif +#ifndef DG_CUBLASLT_CHECK +#define DG_CUBLASLT_CHECK(cmd) \ +do { \ + const auto& e = (cmd); \ + if (e != CUBLAS_STATUS_SUCCESS) { \ + std::ostringstream ss; \ + ss << static_cast(e) << " (" << cublasGetStatusString(e) << ")"; \ + throw DGException("cuBLASLt", __FILE__, __LINE__, ss.str()); \ + } \ +} while (0) +#endif + } // namespace deep_gemm diff --git a/csrc/utils/system.hpp b/csrc/utils/system.hpp index 91dee12..b0e28ba 100644 --- a/csrc/utils/system.hpp +++ b/csrc/utils/system.hpp @@ -1,11 +1,15 @@ #pragma once #include +#include +#include #include #include #include +#include #include "exception.hpp" +#include "format.hpp" namespace deep_gemm { @@ -65,7 +69,10 @@ static std::filesystem::path make_dirs(const std::filesystem::path& path) { // OK if existed std::error_code capture; const bool& created = std::filesystem::create_directories(path, capture); - DG_HOST_ASSERT(created or capture.value() == 0); + if (not (created or capture.value() == 0)) { + DG_HOST_UNREACHABLE(fmt::format("Failed to make directory: {}, created: {}, value: {}", + path.c_str(), created, capture.value())); + } if (created and get_env("DG_JIT_DEBUG")) printf("Create directory: %s\n", path.c_str()); return path; diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 169e2e6..55d87da 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -25,15 +25,22 @@ from deep_gemm_cpp import ( # FP8 GEMMs fp8_gemm_nt, fp8_gemm_nn, fp8_gemm_tn, fp8_gemm_tt, + fp8_gemm_nt_skip_head_mid, m_grouped_fp8_gemm_nt_contiguous, m_grouped_fp8_gemm_nn_contiguous, m_grouped_fp8_gemm_nt_masked, + k_grouped_fp8_gemm_nt_contiguous, k_grouped_fp8_gemm_tn_contiguous, # BF16 GEMMs bf16_gemm_nt, bf16_gemm_nn, bf16_gemm_tn, bf16_gemm_tt, m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nt_masked, + # cuBLASLt GEMMs + cublaslt_gemm_nt, cublaslt_gemm_nn, + cublaslt_gemm_tn, cublaslt_gemm_tt, + # Einsum kernels + einsum, # Layout kernels transform_sf_into_required_layout ) diff --git a/deep_gemm/include/deep_gemm/common/epilogue_utils.cuh b/deep_gemm/include/deep_gemm/common/epilogue_utils.cuh new file mode 100644 index 0000000..5f6a7a1 --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/epilogue_utils.cuh @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +struct EpilogueIdentity { + template + __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) { + return n_idx; + } +}; + +template +struct EpilogueHeadSplits: EpilogueIdentity { + template + __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) { + DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 + and kRight % STORE_BLOCK_N == 0, "Invalid head splits config"); + return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid; + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/common/reduction.cuh b/deep_gemm/include/deep_gemm/common/reduction.cuh new file mode 100644 index 0000000..d9e35f7 --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/reduction.cuh @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include +#include + +#include + +// Operation functors +template struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } }; +template struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } }; +template struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } }; +template struct ReduceAnd { __device__ T operator()(T a, T b) const { return a & b; } }; +template struct ReduceOr { __device__ T operator()(T a, T b) const { return a | b; } }; + +// Unified reduction function +template +__forceinline__ __device__ T warp_reduce(T value, Op op) { + DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or + kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1, + "Invalid number of lanes"); + constexpr uint32_t mask = 0xffffffff; + if constexpr (kIntergroupReduce) { + if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1)); + if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16)); + } else { + if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16)); + if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1)); + } + return value; +} + +// Convenience aliases +template +__forceinline__ __device__ T warp_reduce_sum(T value) { + return warp_reduce(value, ReduceSum{}); +} diff --git a/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep_gemm/include/deep_gemm/common/scheduler.cuh index 2324a9b..237f688 100644 --- a/deep_gemm/include/deep_gemm/common/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -22,7 +22,6 @@ static constexpr uint32_t get_num_1d_blocks_per_group() { if (usage < min_usage) min_usage = usage, num_best_blocks = candidate; } - return num_best_blocks; } @@ -33,6 +32,7 @@ template ()> struct Scheduler { int current_iter = -1; @@ -48,30 +48,40 @@ struct Scheduler { // For grouped GEMM int* grouped_layout; - uint32_t current_group_idx; + uint32_t current_group_idx = 0; // Only used for masked layout - uint32_t current_m_cumsum; + uint32_t current_m_cumsum = 0; // Only used for k-grouped layout - uint32_t current_shape_k, current_num_valid_groups, current_k_cumsum, current_sf_k_cumsum; + uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; + uint32_t next_group_idx, next_shape_k; + + // Only used for k-grouped gemm + __device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const { + for (; group_idx < kNumGroups; ++ group_idx) { + shape_k = __ldg(grouped_layout + group_idx); + if (shape_k > 0) + break; + } + } // ReSharper disable once CppPossiblyUninitializedMember - __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, + __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k, int* grouped_layout = nullptr) { num_m_blocks = ceil_div(shape_m, BLOCK_M); num_n_blocks = ceil_div(shape_n, BLOCK_N); + current_shape_k = shape_k; if constexpr (kGemmType == GemmType::Normal) { num_blocks = num_m_blocks * num_n_blocks; } else if (kGemmType == GemmType::MGroupedContiguous) { num_blocks = num_m_blocks * num_n_blocks; this->grouped_layout = grouped_layout; } else if (kGemmType == GemmType::MGroupedMasked) { - current_group_idx = current_m_cumsum = 0; this->grouped_layout = grouped_layout; } else if (kGemmType == GemmType::KGroupedContiguous) { - current_group_idx = current_num_valid_groups = 0; - current_k_cumsum = current_sf_k_cumsum = 0; - current_shape_k = __ldg(grouped_layout + current_group_idx); this->grouped_layout = grouped_layout; + get_next_k_group(current_group_idx, current_shape_k); + next_group_idx = current_group_idx + 1; + get_next_k_group(next_group_idx, next_shape_k); } } @@ -165,17 +175,17 @@ struct Scheduler { return false; // Within current group - if (current_shape_k > 0 and next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks) + if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks) break; // Move to check the next group - if (current_shape_k > 0) { - current_k_cumsum += current_shape_k; - current_sf_k_cumsum += ceil_div(current_shape_k, 512u); - current_num_valid_groups ++; - } - if ((++ current_group_idx) != kNumGroups) - current_shape_k = __ldg(grouped_layout + current_group_idx); + current_k_cumsum += current_shape_k; + current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT); + current_num_valid_groups ++; + + current_group_idx = next_group_idx ++; + current_shape_k = next_shape_k; + get_next_k_group(next_group_idx, next_shape_k); } get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx); @@ -197,7 +207,7 @@ struct Scheduler { __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { if (num_blocks_in_group == 1) return false; - if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked) { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::KGroupedContiguous) { return true; } else { DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); diff --git a/deep_gemm/include/deep_gemm/common/sm100_utils.cuh b/deep_gemm/include/deep_gemm/common/sm100_utils.cuh index b208302..db1a364 100644 --- a/deep_gemm/include/deep_gemm/common/sm100_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm100_utils.cuh @@ -79,12 +79,24 @@ void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_p desc.start_address_ = static_cast(uint_ptr >> 4); } +__device__ __forceinline__ +static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) { + return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16; +} + // ReSharper disable once CppNotAllPathsReturnValue -template +template constexpr static cute::UMMA::LayoutType to_umma_layout_type() { DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or kSwizzleMode == 32 or kSwizzleMode == 64 or kSwizzleMode == 128, "Invalid swizzling mode"); + // A special case + if constexpr ((cute::is_same_v and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) { + DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base"); + return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B; + } + + // Normal cases if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE; if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; @@ -104,10 +116,12 @@ uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, cons return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); } -template +template __device__ __forceinline__ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { const uint32_t stride_k = get_umma_desc_stride_k(); + const auto& layout_type = to_umma_layout_type(); + const auto& num_non_contiguous = 128 / get_atom_base(layout_type); if constexpr (kMajorMode == cute::UMMA::Major::K) { // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); @@ -115,9 +129,9 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id // Atom size: 8 x `kSwizzleMode` (in bytes, on K) // {SBO, LBO} means the byte stride between atoms on {MN, K} // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 - const uint32_t stride_byte_offset = 8 * BLOCK_K * sizeof(dtype_t); + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); const uint32_t leading_byte_offset = 0; - return make_smem_desc(to_umma_layout_type(), + return make_smem_desc(layout_type, base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, stride_byte_offset, leading_byte_offset); } else { @@ -132,11 +146,11 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling - uint32_t stride_byte_offset = 8 * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); if constexpr (kSwizzleMode == 16) swap(stride_byte_offset, leading_byte_offset); - return make_smem_desc(to_umma_layout_type(), + return make_smem_desc(layout_type, base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, stride_byte_offset, leading_byte_offset); } @@ -166,4 +180,81 @@ __device__ __forceinline__ void tcgen05_after_thread_sync() { asm volatile("tcgen05.fence::after_thread_sync;"); } +// UMMA versions with relaxed assertions +struct SM100_MMA_F16BF16_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F16BF16_2x1SM_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF8F6F4_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_MXF8F6F4_2x1SM_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + } // namespace `deep_gemm::sm100` diff --git a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh index e590b47..d587f42 100644 --- a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -1,8 +1,12 @@ #pragma once +#include +#include #include #include +#include + namespace deep_gemm::sm90 { template @@ -29,6 +33,7 @@ struct FP8MMASelector { static constexpr auto select_mma() { using namespace cute::SM90::GMMA; + if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN(); if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); @@ -93,6 +98,7 @@ struct BF16MMASelector { static constexpr auto select_mma() { using namespace cute::SM90::GMMA; + if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); @@ -144,6 +150,24 @@ struct SM90_U32x2_STSM_N { } }; +struct SM90_U32x2_LDSM_N { + __device__ __forceinline__ static void + copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst_0), "=r"(dst_1) + : "l"(smem_src)); + } +}; + +struct SM90_U32x4_LDSM_N { + __device__ __forceinline__ static void + copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3) + : "l"(smem_src)); + } +}; + __forceinline__ __device__ void warpgroup_arrive() { asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); } @@ -223,4 +247,37 @@ tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, } } +__device__ __forceinline__ void +tma_3d_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, + const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& crd_2) { + constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); + cute::SM90_TMA_LOAD_3D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1, crd_2); +} + +// Tensormap related +__device__ __forceinline__ void tensor_map_release_cta() { + asm volatile ("fence.proxy.tensormap::generic.release.cta;"); +} + +__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) { + auto gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory"); +} + +__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) { + auto smem_int_desc = static_cast(__cvta_generic_to_shared(smem_desc)); + const auto new_int64_addr = reinterpret_cast(new_addr); + asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr)); +} + +__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) { + auto smem_int_desc = __cvta_generic_to_shared(smem_desc); + asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim)); +#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 5))) + asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); +#else + DG_STATIC_ASSERT(false, "Invalid CUDA version") +#endif +} + } // namespace `deep_gemm::sm90` diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh index fc84b69..0b7ff11 100644 --- a/deep_gemm/include/deep_gemm/common/utils.cuh +++ b/deep_gemm/include/deep_gemm/common/utils.cuh @@ -104,6 +104,12 @@ __device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) { return ret; } +__device__ __forceinline__ float2 ld_shared(const float2* ptr) { + float2 ret; + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr)); + return ret; +} + __device__ __forceinline__ float4 ld_shared(const float4* ptr) { float4 ret; asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(ptr)); @@ -126,10 +132,18 @@ __device__ __forceinline__ void st_shared(const float* ptr, float val) { asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); } +__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { + asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y)); +} + __device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val)); } +__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) { + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(ptr), "r"(x), "r"(y)); +} + __device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(ptr), "r"(x), "r"(y), "r"(z), "r"(w)); } diff --git a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh index 46a668d..6b23535 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -17,7 +17,7 @@ template (); // Prefetch TMA descriptors at the very beginning - if (threadIdx.x == 0) { - // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + if (warp_idx == 0 and cute::elect_one_sync()) { cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(&tensor_map_b); cute::prefetch_tma_descriptor(&tensor_map_d); @@ -93,35 +92,31 @@ sm100_bf16_gemm_impl(int* grouped_layout, cute::prefetch_tma_descriptor(&tensor_map_c); } - // Data on shared memory (layout as ordered below) - cd_dtype_t* smem_cd[kNumTMAStoreStages]; - cutlass::bfloat16_t* smem_a[kNumStages]; - cutlass::bfloat16_t* smem_b[kNumStages]; - - // Fill D/A/B pointers - #pragma unroll - for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i) - smem_cd[i] = reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - smem_a[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); - smem_b[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); - } + // D/A/B shared memory + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); // Fill barriers - auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + - kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2; // Fill the tensor memory pointer - auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1); DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); // Initialize barriers - if (threadIdx.x == 0) { + if (warp_idx == 1 and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumStages; ++ i) { // Arrive only at the leader CTA @@ -136,11 +131,12 @@ sm100_bf16_gemm_impl(int* grouped_layout, // Arrive only at the leader CTA tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads); } + if constexpr (kTensorCoreUtilControl < 100) + tensor_core_full_barrier->init(1); // Make initialized barrier visible in async proxy - cutlass::arch::fence_view_async_shared(); cutlass::arch::fence_barrier_init(); - } else if (threadIdx.x >= 32 and threadIdx.x < 64) { + } else if (warp_idx == 2) { // Allocate tensor memory Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); } @@ -148,100 +144,69 @@ sm100_bf16_gemm_impl(int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); - // For pipeline unrolling - struct DivisibleK {}; - struct NotDivisibleK {}; - uint32_t phase = 0; - auto launch_k_iterations = [&](const auto& func) { - const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); - const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K); - const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages; + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; - // TODO: refactor here - if (num_last_stages == 0) { - for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1) - func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages); - } else { - for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1) - func(k_iter, DivisibleK{}, false, num_last_stages); - func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1; - } - }; - - auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) { - DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2, - "Too many epilogue stages, please modify the Python heuristic as well"); - accum_stage_idx == 0 ? func(0) : func(1); + // Flip phases only if reach the next first stage + stage_idx = (stage_idx + 1) % kNumStages; + phase ^= stage_idx == 0; }; // Dispatch warps into different roles - if (warp_idx == 0) { + if (warp_idx == 0 and cute::elect_one_sync()) { // TMA load warp // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { - constexpr bool kHasDivisibleStages = cute::is_same_v; - const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + #pragma unroll + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait(phase ^ 1); + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); - // Compute offsets - // NOTES: the group is always concatenated with the outer dimension - uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> ( - shape_m, BLOCK_M, m_block_idx); - uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> ( - shape_n, BLOCK_N, n_block_idx, m_block_idx); + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); - // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major - // And for all m-grouped GEMMs, A must be K-majored - DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); - uint32_t k_block_idx = k_iter * kNumStages + s; - uint32_t k_idx = k_block_idx * BLOCK_K; - uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> ( - shape_k, BLOCK_K, k_block_idx, m_block_idx); - uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> ( - shape_k, BLOCK_K, k_block_idx, m_block_idx); - - // Add 2 CTA offsets - if constexpr (kNumMulticast > 1) { - m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; - n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); - } - - // Issue TMAs - if (cute::elect_one_sync()) { - if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx); - if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx); - if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx); - if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx); - } - // Arrive at full barriers - constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; - if (is_leader_cta and cute::elect_one_sync()) - full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); - if (not is_leader_cta and cute::elect_one_sync()) - full_barriers[s]->arrive(0u); + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); } - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait(phase ^ 1); - if (is_leader_cta and cute::elect_one_sync()) - full_barriers[s]->arrive(); - if (not is_leader_cta and cute::elect_one_sync()) - full_barriers[s]->arrive(0u); + // Issue TMAs + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); + } else { + full_barriers[stage_idx]->arrive(0u); } - }); + } } } else if (warp_idx == 1 and is_leader_cta) { // MMA issue warp @@ -268,88 +233,89 @@ sm100_bf16_gemm_impl(int* grouped_layout, // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { - // Wait tensor memory empty barrier arrival - auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; - tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // UMMA and empty barrier arrival alias + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); tcgen05_after_thread_sync(); - // Empty barrier arrival - auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) { - auto umma_arrive = [](const uint64_t* barrier) { - if constexpr (kNumMulticast == 1) { - cutlass::arch::umma_arrive(barrier); - } else { - constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; - cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); - } - }; - umma_arrive(reinterpret_cast(empty_barriers[s])); - - // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting - if (do_tmem_full_arrive) - umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); - }; - - // Launch MMAs - launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { - constexpr bool kHasDivisibleStages = cute::is_same_v; - const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; - + // Issue UMMA in the leader CTA + using mma_t = cute::conditional_t; + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + if (cute::elect_one_sync()) { #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait TMA arrival - full_barriers[s]->wait(phase); - tcgen05_after_thread_sync(); - - // Let tensor cores relax for lower possibility of frequency drop - DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control"); - if constexpr (kTensorCoreUtilControl < 100) { - constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull; - constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; - const auto& start_clock = clock64(); - if (cute::elect_one_sync()) - while (clock64() - start_clock < kNumDummyCycles) {} - __syncwarp(); - } - - // Issue UMMA in the leader CTA - using cute_mma_t = cute::conditional_t, - cute::SM100_MMA_F16BF16_2x1SM_SS>; - const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); - const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s); - const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s); + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); #pragma unroll - for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); - #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K); - cute_mma_t::fma(a_desc, b_desc, - accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, - k_iter > 0 or s > 0 or k > 0, - runtime_instr_desc); - } + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K); + mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_block_idx > 0 or k > 0, + runtime_instr_desc); } - - // Commit to the mbarrier object - // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` - empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1); } + } - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait(phase); - empty_barrier_arrive(s, false); - } - }); - }); + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + + // Let tensor cores relax for lower possibility of frequency drop + DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control"); + if constexpr (kTensorCoreUtilControl < 100) { + // For utilization control + umma_arrive(reinterpret_cast(tensor_core_full_barrier)); + + // Wait for last UMMA to be done + tensor_core_full_barrier->wait(tensor_core_phase); + tensor_core_phase ^= 1; + + // Sleep for certain cycles + constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; + const auto& start_clock = clock64(); + if (cute::elect_one_sync()) + while (clock64() - start_clock < kNumDummyCycles) {} + __syncwarp(); + } + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto& iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); } } else if (warp_idx >= kNumNonEpilogueThreads / 32) { // Epilogue warp groups - const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads; const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, @@ -363,129 +329,114 @@ sm100_bf16_gemm_impl(int* grouped_layout, DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { - auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; - // Flush TMA stores - // NOTES: for the first store, we have to flush all previous TMA, - // as we don't share pipeline stages between two blocks - if (epilogue_thread_idx == 0) - cute::tma_store_wait<0>(); - cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); - // Wait UMMA arrival - tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); - tcgen05_after_thread_sync(); + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); - // Load from tensor memory into registers, and write shared memory with STSM - DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); - DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); - - // Iterate over M waves + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - // Issue every swizzled atom and pipeline STSM and TMA store - constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0); + + // The pipeline stage + const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory #pragma unroll - for (uint32_t s = 0; s < kNumStores; ++ s) { - // Wait shared memory to be released - const uint32_t iter_idx = w * kNumStores + s; - if (iter_idx >= kNumTMAStoreStages) { - if (epilogue_thread_idx == 0) - cute::tma_store_wait(); - cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); - } + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); - // The pipeline stage - const auto tma_stage_idx = iter_idx % kNumTMAStoreStages; - const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M; - const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); - // Store into shared memory - #pragma unroll - for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { - // Calculate the index of the bank group to be written in the atom - auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - // Reshape the atom in another view and swizzle - // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` - // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` - // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern - constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; - auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); - auto col = kHasShortcut ? (i) : (bank_group_index % 8); - col ^= row % (kSwizzleCDMode / 16); - - // Source and destination memory address - uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset - w * BLOCK_N + // Wave offset - s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset - auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer - epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset - row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - - // Load from tensor memory, store into shared memory - uint32_t values[kNumElemsPerBankGroup]; - if constexpr (cute::is_same_v) { - // For FP32 output, read and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, - values[0], values[1], values[2], values[3]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); - } else { - // For BF16 output, read, cast and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, - values[0], values[1], values[2], values[3], - values[4], values[5], values[6], values[7]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, - cast_into_bf16_and_pack(values[0], values[1]), - cast_into_bf16_and_pack(values[2], values[3]), - cast_into_bf16_and_pack(values[4], values[5]), - cast_into_bf16_and_pack(values[6], values[7])); - } - } - - // Notify tensor memory empty (only at the leader CTA) arrival ASAP - // NOTES: only the last stage needs to do this - if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { - tcgen05_before_thread_sync(); - tmem_empty_barriers[accum_stage_idx]->arrive(0u); - } - __syncwarp(); - - // Synchronize all threads and issue TMA - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); - if (epilogue_thread_idx == 0) { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); - cute::tma_store_arrive(); + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); } } - } - }); - } - // Flush all stages in the pipeline to make TMA stores visible to the next kernel - if (epilogue_thread_idx == 0) - cute::tma_store_wait<0>(); + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + __syncwarp(); + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + } + } // Deallocate tensor memory by warp 1 // NOTES: warp 0 is waiting TMA store if (epilogue_warp_idx == 1) Allocator().free(0, kNumTmemCols); } - - // To safely deconstruct all barriers, we need a cluster sync - // TODO: optimize it by another round of barrier waits - if constexpr (kNumMulticast > 1) - cute::cluster_sync(); #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); diff --git a/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh b/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh new file mode 100644 index 0000000..8a0130b --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh @@ -0,0 +1,265 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumThreads, 1) +sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumTMAStoreStages = 2; + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size"); + DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode"); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Shared memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Fill D/A/B + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE)); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 2 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Block indices + const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; + const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; + const uint32_t n_block_idx = mn_block_idx % num_n_blocks; + const uint32_t m_block_idx = mn_block_idx / num_n_blocks; + const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + + if (warp_idx == 0) { + // TMA load warp + for (uint32_t s = 0; s < num_total_stages; ++ s) { + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + uint32_t m_idx = BLOCK_M * m_block_idx; + uint32_t n_idx = BLOCK_N * n_block_idx; + uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + uint32_t k_idx = sk_idx % SHAPE_K; + uint32_t s_idx = sk_idx / SHAPE_K; + + // Issue TMAs + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); + } + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (cute::elect_one_sync()) + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } else if (warp_idx == 1) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + constexpr uint32_t UMMA_M = LAYOUT_AD_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Wait tensor memory empty barrier arrival + tcgen05_after_thread_sync(); + + // Launch MMAs + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + tcgen05_after_thread_sync(); + + // Issue UMMA in the leader CTA + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, stage_idx); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + if (warp_idx == 2) + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + // Wait shared memory to be released + if (s >= kNumTMAStoreStages) { + if (warp_idx == 0 and cute::elect_one_sync()) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(kNumThreads).sync(); + } + + // The pipeline stage + const auto tma_stage_idx = s % kNumTMAStoreStages; + const auto m_idx = m_block_idx * BLOCK_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumThreads).sync(); + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is doing TMA stores + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); +#endif +} + +} \ No newline at end of file diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 03c44cd..322c0fd 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -4,6 +4,7 @@ #include +#include #include #include #include @@ -17,11 +18,12 @@ template + GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t, + typename epilogue_type_t> __global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_fp8_gemm_1d1d_impl(int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, @@ -96,8 +98,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; // Prefetch TMA descriptors at the very beginning - if (threadIdx.x == 0) { - // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + if (warp_idx == 0 and cute::elect_one_sync()) { cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(&tensor_map_b); cute::prefetch_tma_descriptor(&tensor_map_sfa); @@ -107,30 +108,25 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, cute::prefetch_tma_descriptor(&tensor_map_c); } - // Data on shared memory (layout as ordered below) - cd_dtype_t* smem_cd[kNumTMAStoreStages]; - cutlass::float_e4m3_t* smem_a[kNumStages]; - cutlass::float_e4m3_t* smem_b[kNumStages]; - uint32_t* smem_sfa[kNumStages]; - uint32_t* smem_sfb[kNumStages]; + // D/A/B shared memory + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); - // Fill D/A/B pointers - #pragma unroll - for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i) - smem_cd[i] = reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - smem_a[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); - smem_b[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); - } - - // Fill SFA/SFB + // SFA/SFB shared memory auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - smem_sfa[i] = reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); - smem_sfb[i] = reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); - } + auto smem_sfa = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); // Fill barriers auto barrier_start_ptr = reinterpret_cast(smem_buffer + @@ -148,7 +144,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); // Initialize barriers - if (threadIdx.x == 0) { + if (warp_idx == 1 and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumStages; ++ i) { // Arrive at all CTAs @@ -166,9 +162,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, } // Make initialized barrier visible in async proxy - cutlass::arch::fence_view_async_shared(); cutlass::arch::fence_barrier_init(); - } else if (threadIdx.x >= 32 and threadIdx.x < 64) { + } else if (warp_idx == 2) { // Allocate tensor memory Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); } @@ -176,108 +171,75 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); - // For pipeline unrolling - struct DivisibleK {}; - struct NotDivisibleK {}; - uint32_t phase = 0; - auto launch_k_iterations = [&](const auto& func) { - const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); - const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K); - const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages; + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; - // TODO: refactor here - if (num_last_stages == 0) { - for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1) - func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages); - } else { - for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1) - func(k_iter, DivisibleK{}, false, num_last_stages); - func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1; - } - }; - - auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) { - DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2, - "Too many epilogue stages, please modify the Python heuristic as well"); - accum_stage_idx == 0 ? func(0) : func(1); + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; }; // Dispatch warps into different roles - if (warp_idx == 0) { + if (warp_idx == 0 and cute::elect_one_sync()) { // TMA load warp // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { - constexpr bool kHasDivisibleStages = cute::is_same_v; - const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait(phase ^ 1); + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); - // Compute offsets - // NOTES: the group is always concatenated with the outer dimension - uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> ( - shape_m, BLOCK_M, m_block_idx); - uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> ( - shape_n, BLOCK_N, n_block_idx, m_block_idx); + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); - // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major - // And for all m-grouped GEMMs, A must be K-majored - DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); - uint32_t k_block_idx = k_iter * kNumStages + s; - uint32_t k_idx = k_block_idx * BLOCK_K; - uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> ( - shape_k, BLOCK_K, k_block_idx, m_block_idx); - uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> ( - shape_k, BLOCK_K, k_block_idx, m_block_idx); - - // Add 2 CTA offsets - if constexpr (kNumMulticast > 1) { - m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; - n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); - } - - // Issue TMAs - if (cute::elect_one_sync()) { - if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx); - if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx); - if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx); - if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx); - } - auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; - - // Issue SFA and SFB TMAs at certain stages - // No swizzling, so one TMA for one SF is enough - const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad; - if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) { - tma_copy(&tensor_map_sfa, full_barriers[s], smem_sfa[s], m_block_idx * BLOCK_M, - scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad))); - tma_copy(&tensor_map_sfb, full_barriers[s], smem_sfb[s], n_block_idx * BLOCK_N, - scheduler.template get_global_idx(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx)); - num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t); - } - - // Arrive at full barriers - if (cute::elect_one_sync()) - full_barriers[s]->arrive_and_expect_tx(num_arrival_bytes); + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); } - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait(phase ^ 1); - if (cute::elect_one_sync()) - full_barriers[s]->arrive(); + // Issue TMAs + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx); + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad; + if (sf_stage_in_group_idx == 0) { + tma_copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M, + scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad))); + tma_copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N, + scheduler.template get_global_idx(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx)); + num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t); } - }); + + // Arrive at full barriers + full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); + } } } else if (warp_idx == 1 and is_leader_cta) { // MMA issue warp @@ -307,101 +269,93 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { - // Wait tensor memory empty barrier arrival - auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; - tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA and SF-transpose arrival + with_sf_full_barriers[stage_idx]->wait(phase); tcgen05_after_thread_sync(); - // Empty barrier arrival - auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) { - auto umma_arrive = [](const uint64_t* barrier) { - if constexpr (kNumMulticast == 1) { - cutlass::arch::umma_arrive(barrier); - } else { - constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; - cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); - } - }; - umma_arrive(reinterpret_cast(empty_barriers[s])); - - // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting - if (do_tmem_full_arrive) - umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); - }; - - // Launch MMAs - launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { - constexpr bool kHasDivisibleStages = cute::is_same_v; - const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + // Do SF copy at certain stages + // NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves + const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad; + if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) { + using cute_utccp_t = cute::conditional_t; + // SFA and SFB copy + // TODO: process shared memory descriptor by addition #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait TMA and SF-transpose arrival - with_sf_full_barriers[s]->wait(phase); - tcgen05_after_thread_sync(); + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + __syncwarp(); - // Do SF copy at certain stages - // NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves - const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad; - if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) { - using cute_utccp_t = cute::conditional_t; - - // SFA and SFB copy - // TODO: process shared memory descriptor by addition - #pragma unroll - for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { - auto smem_ptr = smem_sfa[s] + i * kNumUTCCPAlignedElems; - replace_smem_desc_addr(sf_desc, smem_ptr); - cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); - } - #pragma unroll - for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { - auto smem_ptr = smem_sfb[s] + i * kNumUTCCPAlignedElems; - replace_smem_desc_addr(sf_desc, smem_ptr); - cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); - } - } - __syncwarp(); - - // Issue UMMA in the leader CTA - using cute_mma_t = cute::conditional_t, - cute::SM100_MMA_MXF8F6F4_2x1SM_SS>; - const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx); - const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s); - const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s); + // Issue UMMA in the leader CTA + using mma_t = cute::conditional_t; + const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); #pragma unroll - for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); - #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K); - cute_mma_t::fma(a_desc, b_desc, - accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, - k_iter > 0 or s > 0 or k > 0, - runtime_instr_desc, - kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32), - kTmemStartColOfSFB); - } + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K); + mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_block_idx > 0 or k > 0, + runtime_instr_desc, + kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32), + kTmemStartColOfSFB); } - - // Commit to the mbarrier object - // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` - empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1); } + } - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - with_sf_full_barriers[s]->wait(phase); - empty_barrier_arrive(s, false); - } - }); - }); + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto& iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); } } else if (warp_idx == 2) { // UTCCP transposer @@ -418,43 +372,30 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, }; while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { - constexpr bool kHasDivisibleStages = cute::is_same_v; - const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait TMA arrival - full_barriers[s]->wait(phase); - - // Transpose for UTCCP at certain stages - const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad; - if (sf_stage_in_group_idx == 0) { - #pragma unroll - for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) - utccp_required_smem_warp_transpose(smem_sfa[s] + i * kNumUTCCPAlignedElems); - #pragma unroll - for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) - utccp_required_smem_warp_transpose(smem_sfb[s] + i * kNumUTCCPAlignedElems); - // TODO: figure out whether the proxy fence is valid for 2-CTA cases - cutlass::arch::fence_view_async_shared(); - } - - // Arrive - with_sf_full_barriers[s]->arrive(0u); + // Transpose for UTCCP at certain stages + const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad; + if (sf_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); } - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait(phase); - with_sf_full_barriers[s]->arrive(0u); - } - }); + // Arrive + with_sf_full_barriers[stage_idx]->arrive(0u); + } } } else if (warp_idx >= kNumNonEpilogueThreads / 32) { // Epilogue warp groups - const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads; const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, @@ -468,129 +409,113 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { - auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; - // Flush TMA stores - // NOTES: for the first store, we have to flush all previous TMA, - // as we don't share pipeline stages between two blocks - if (epilogue_thread_idx == 0) - cute::tma_store_wait<0>(); - cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); - // Wait UMMA arrival - tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); - tcgen05_after_thread_sync(); + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); - // Load from tensor memory into registers, and write shared memory with STSM - DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); - DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); - - // Iterate over M waves + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - // Issue every swizzled atom and pipeline STSM and TMA store - constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0); + + // The pipeline stage + const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M; + const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); + + // Store into shared memory #pragma unroll - for (uint32_t s = 0; s < kNumStores; ++ s) { - // Wait shared memory to be released - const uint32_t iter_idx = w * kNumStores + s; - if (iter_idx >= kNumTMAStoreStages) { - if (epilogue_thread_idx == 0) - cute::tma_store_wait(); - cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); - } + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); - // The pipeline stage - const auto tma_stage_idx = iter_idx % kNumTMAStoreStages; - const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M; - const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); - // Store into shared memory - #pragma unroll - for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { - // Calculate the index of the bank group to be written in the atom - auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - // Reshape the atom in another view and swizzle - // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` - // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` - // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern - constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; - auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); - auto col = kHasShortcut ? (i) : (bank_group_index % 8); - col ^= row % (kSwizzleCDMode / 16); - - // Source and destination memory address - uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset - w * BLOCK_N + // Wave offset - s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset - auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer - epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset - row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - - // Load from tensor memory, store into shared memory - uint32_t values[kNumElemsPerBankGroup]; - if constexpr (cute::is_same_v) { - // For FP32 output, read and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, - values[0], values[1], values[2], values[3]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); - } else { - // For BF16 output, read, cast and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, - values[0], values[1], values[2], values[3], - values[4], values[5], values[6], values[7]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, - cast_into_bf16_and_pack(values[0], values[1]), - cast_into_bf16_and_pack(values[2], values[3]), - cast_into_bf16_and_pack(values[4], values[5]), - cast_into_bf16_and_pack(values[6], values[7])); - } - } - - // Notify tensor memory empty (only at the leader CTA) arrival ASAP - // NOTES: only the last stage needs to do this - if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { - tcgen05_before_thread_sync(); - tmem_empty_barriers[accum_stage_idx]->arrive(0u); - } - __syncwarp(); - - // Synchronize all threads and issue TMA - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); - if (epilogue_thread_idx == 0) { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); - cute::tma_store_arrive(); + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); } } - } - }); - } - // Flush all stages in the pipeline to make TMA stores visible to the next kernel - if (epilogue_thread_idx == 0) - cute::tma_store_wait<0>(); + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + } + } // Deallocate tensor memory by warp 1 // NOTES: warp 0 is waiting TMA store if (epilogue_warp_idx == 1) Allocator().free(0, kNumTmemCols); } - - // To safely deconstruct all barriers, we need a cluster sync - // TODO: optimize it by another round of barrier waits - if constexpr (kNumMulticast > 1) - cute::cluster_sync(); #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh index e04db3c..658f883 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -22,7 +23,8 @@ template + GemmType kGemmType, typename cd_dtype_t, + typename epilogue_type_t> __global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, @@ -88,8 +90,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); // Prefetch TMA descriptors at the very beginning - if (threadIdx.x == 0) { - // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + if (warp_idx == 0 and cute::elect_one_sync()) { cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(&tensor_map_b); cute::prefetch_tma_descriptor(&tensor_map_d); @@ -133,7 +134,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); // Initialize barriers - if (threadIdx.x == 0) { + if (warp_idx == 1 and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumStages; ++ i) { // Arrive at all CTAs @@ -149,9 +150,8 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, } // Make initialized barrier visible in async proxy - cutlass::arch::fence_view_async_shared(); cutlass::arch::fence_barrier_init(); - } else if (threadIdx.x >= 32 and threadIdx.x < 64) { + } else if (warp_idx == 2) { // Allocate tensor memory Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); } @@ -174,7 +174,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); // Register configurations constexpr uint32_t kNumNonEpilogueRegisters = 64; @@ -435,7 +435,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // as we don't share pipeline stages between two blocks if (epilogue_thread_idx_in_warpgroup == 0) cute::tma_store_wait<0>(); - cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync(); + cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx); // Write shared memory DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); @@ -449,13 +449,13 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, if (s >= kNumTMAStoreStages) { if (epilogue_thread_idx_in_warpgroup == 0) cute::tma_store_wait(); - cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync(); + cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx); } // The pipeline stage const auto tma_stage_idx = s % kNumTMAStoreStages; const auto m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_m, BLOCK_M, m_block_idx); - const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); const auto local_smem_cd = smem_cd[tma_stage_idx] + epilogue_warpgroup_idx * STORE_BLOCK_M * STORE_BLOCK_N; // Store into shared memory @@ -502,7 +502,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Synchronize all threads and issue TMA cute::tma_store_fence(); - cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync(); + cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx); if (epilogue_thread_idx_in_warpgroup == 0) { cute::SM90_TMA_STORE_2D::copy( &tensor_map_d, local_smem_cd, @@ -512,10 +512,6 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, } } - // Flush all stages in the pipeline to make TMA stores visible to the next kernel - if (epilogue_thread_idx_in_warpgroup == 0) - cute::tma_store_wait<0>(); - // Deallocate tensor memory by warp 1 // NOTES: warp 0 is waiting TMA store if (epilogue_warp_idx == 1) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh index 23045e1..9186e68 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -25,7 +25,8 @@ template + uint32_t kNumSMs, GemmType kGemmType, + typename cd_dtype_t> __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_bf16_gemm_impl(int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, @@ -44,7 +45,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; // Shared memory - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(cd_dtype_t); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); @@ -55,7 +56,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, const uint32_t lane_idx = get_lane_idx(); // Prefetch TMA descriptors at the very beginning - if (threadIdx.x == kNumMathThreads) { + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(&tensor_map_b); cute::prefetch_tma_descriptor(&tensor_map_d); @@ -67,7 +68,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); // Data on shared memory - auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + auto smem_d = reinterpret_cast(smem_buffer); __nv_bfloat16* smem_a[kNumStages]; __nv_bfloat16* smem_b[kNumStages]; @@ -91,7 +92,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, } // Initialize barriers - if (threadIdx.x == kNumMathThreads) { + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < kNumStages; ++ i) { full_barriers[i]->init(1); @@ -99,7 +100,6 @@ sm90_bf16_gemm_impl(int* grouped_layout, } // Make initialized barrier visible in async proxy - cutlass::arch::fence_view_async_shared(); cutlass::arch::fence_barrier_init(); } @@ -125,14 +125,14 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); - if (threadIdx.x >= kNumMathThreads) { + if (warp_idx >= kNumMathThreads / 32) { // TMA warp-group for loading data cutlass::arch::warpgroup_reg_dealloc(); // NOTES: only one thread (or warp) will be used - if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) { + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { launch_k_iterations([&](uint32_t k_iter, auto divisible_type) { @@ -203,7 +203,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, } }; - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); // Launch MMAs launch_k_iterations([&](uint32_t k_iter, auto divisible_type) { @@ -237,11 +237,10 @@ sm90_bf16_gemm_impl(int* grouped_layout, for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) warpgroup_fence_operand(accum[i]); warpgroup_wait<0>(); - - // Notify barrier arrival at the last warpgroup wave - if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) - empty_barrier_arrive(s); } + + // Notify barrier arrival + empty_barrier_arrive(s); } // Wait unaligned cases @@ -256,7 +255,6 @@ sm90_bf16_gemm_impl(int* grouped_layout, constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; - DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type"); DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, "Unaligned TMA store or too many TMA store instructions"); @@ -265,60 +263,76 @@ sm90_bf16_gemm_impl(int* grouped_layout, // Wait last TMA store to be finished if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) cute::tma_store_wait<0>(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); - // Write back to shared memory using STSM and issue TMA stores - DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); - #pragma unroll - for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { - auto m_offset = local_idx * WAVE_BLOCK_M; - auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + if constexpr (std::is_same_v) { + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type"); + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - // Swizzle or padding into the correct address - uint8_t* smem_ptr = nullptr; - if constexpr (kSwizzleDMode > 0) { - // Calculate the swizzling atom offset and in-atom offset - constexpr uint32_t kNumBankGroupBytes = 16; - auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); - // Calculate the index of the bank group to be written in the atom - auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); - // Reshape the atom in another view and swizzle - // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` - // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` - constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; - auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); - auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); - col ^= row % (kSwizzleDMode / 16); + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); - // Add back into the base pointer - // NOTES: think twice before modifying this, as changes may affect the number of instructions - smem_ptr = reinterpret_cast(smem_d) + // Base pointer - warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset - m_offset * kSwizzleDMode + // Wave offset - atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) - row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - } else { - // No swizzling, just padding - // TODO: support more cases - smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + } + else { + // Use `st.shared` if STSM is not available + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + auto smem_d_0 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2); + auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); + st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); } - - // NOTES: only 16 lanes' addresses are used - SM90_U32x2_STSM_N::copy( - __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), - __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), - smem_ptr - ); } } cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); // Use TMA store to write back to global memory - // TODO: compatible with FP32 output constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { diff --git a/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh b/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh new file mode 100644 index 0000000..d40308e --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh @@ -0,0 +1,173 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + float *d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Types + using WGMMA = typename BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Shared memory + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M"); + DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads"); + DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads"); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + // Fill shared memory pointers + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumMathThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Block indices + const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; + const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; + const uint32_t n_block_idx = mn_block_idx % num_n_blocks; + const uint32_t m_block_idx = mn_block_idx / num_n_blocks; + const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + // Persistently schedule over blocks + #pragma unroll + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1); + + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + const uint32_t& k_idx = sk_idx % SHAPE_K; + const uint32_t& s_idx = sk_idx / SHAPE_K; + + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1); + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + float accum[WGMMA::kNumAccum] = {0}; + + // Launch MMAs + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrivals + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, 1); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + empty_barriers[stage_idx]->arrive(); + } + + const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4; + const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + if (col + i * 8 >= SHAPE_N) + break; + if (row < SHAPE_M) { + atomicAdd(reinterpret_cast(d + (row + 0) * SHAPE_N + col + i * 8), + make_float2(accum[i * 4 + 0], accum[i * 4 + 1])); + } + if (row + 8 < SHAPE_M) { + atomicAdd(reinterpret_cast(d + (row + 8) * SHAPE_N + col + i * 8), + make_float2(accum[i * 4 + 2], accum[i * 4 + 3])); + } + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh index 28b5399..4c57cbe 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -1,3 +1,348 @@ #pragma once -// TODO: add implement \ No newline at end of file +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, + int* grouped_layout, + cute::TmaDescriptor* tensor_map_buffer, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a_base, + const __grid_constant__ cute::TmaDescriptor tensor_map_b_base, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads"); + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); + DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment"); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = threadIdx.x % 32; + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a_base); + cute::prefetch_tma_descriptor(&tensor_map_b_base); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Tensor maps on shared and global memory + auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * i); + }); + auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * (2 + i)); + }); + auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; }); + auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; }); + + // Data on shared memory + auto smem_d = reinterpret_cast(smem_buffer + SMEM_TENSOR_MAP_SIZE); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE)); + }); + auto smem_sfb = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE)); + }); + + // Barriers on shared memory + constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE); + auto full_barriers = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast(sizeof(Barrier)))); + }); + auto empty_barriers = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast(sizeof(Barrier)))); + }); + + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + // Load tensormap A/B to shared memory + if constexpr (kGemmType == GemmType::KGroupedContiguous) { + *smem_tensor_map_a[0] = tensor_map_a_base; + *smem_tensor_map_a[1] = tensor_map_a_base; + *smem_tensor_map_b[0] = tensor_map_b_base; + *smem_tensor_map_b[1] = tensor_map_b_base; + } + + // Initialize barriers + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Pipeline unroll control + constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages); + + // Register reconfigurations (more math registers are needed with unrolling) + constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24); + constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // TMA and MMA pipeline + const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { + return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase + }; + uint32_t iter_idx = 0; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base; + const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base; + uint32_t last_group_idx = kNumGroups, sum_k = 0; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const uint32_t& m_idx = m_block_idx * BLOCK_M; + const uint32_t& n_idx = n_block_idx * BLOCK_N; + + if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) { + const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1; + const uint32_t& next_stage_idx = stage_idx ^ 1; + last_group_idx = scheduler.current_group_idx; + + // Prepare next tensor map + sum_k += scheduler.current_shape_k; + if (scheduler.next_group_idx < kNumGroups) { + tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + sum_k * shape_m); + tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + sum_k * shape_n); + tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); + tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); + *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]); + *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]); + tensor_map_release_cta(); + } + + // Get current tensor map + if (scheduler.current_num_valid_groups > 0) { + tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]); + tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]); + current_tensor_map_a = gmem_tensor_map_a[stage_idx]; + current_tensor_map_b = gmem_tensor_map_b[stage_idx]; + } + } + + #pragma unroll kNumPipelineUnrolls + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Issue TMA + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t& k_idx = k_block_idx * BLOCK_K; + const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; + tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); + tma_copy(&tensor_map_sfb, reinterpret_cast(&full_barrier), smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); + tma_copy(current_tensor_map_a, reinterpret_cast(&full_barrier), smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); + tma_copy(current_tensor_map_b, reinterpret_cast(&full_barrier), smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) { + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4; + const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Accumulation for WGMMA or CUDA promotion + DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes"); + const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0); + const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + float2 scales_b[WGMMA::kNumAccum / 4]; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + #pragma unroll kNumPipelineUnrolls + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) { + // Wait TMA arrivals + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + full_barriers[stage_idx]->wait(phase); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0); + auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1); + + // Read B scales + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + scales_b[i] = ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2)); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(stage_idx); + + // Promote with scales + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + const float &scale_b_0 = scales_b[i].x; + const float &scale_b_1 = scales_b[i].y; + final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3]; + } + } + + // Flush previous stores + if (warp_idx % 4 == 0 and cute::elect_one_sync()) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Store to D shared memory + const auto& smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); + const auto& smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); + st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Use TMA store to write back to global memory + if (warp_idx % 4 == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy( + &tensor_map_d, smem_d_0, n_block_idx * BLOCK_N, + current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 5a65d69..5a92d7d 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -18,15 +19,15 @@ namespace deep_gemm { using namespace deep_gemm::sm90; -template -__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) { +template +__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) { if (num_former_iters == kNumFormerIters) { - inner_launch_k_iterations(func, cute::Int{}); + func(cute::Int{}); return; } if constexpr (kNumFormerIters + kGap <= kEnd) - outer_launch_k_iterations(inner_launch_k_iterations, func, num_former_iters); + dispatch_num_former_iters(num_former_iters, func); } template + uint32_t kNumSMs, GemmType kGemmType, + typename epilogue_type_t> __global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, @@ -69,14 +71,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); // Configs - constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; - const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); + const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K); const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); const uint32_t lane_idx = get_lane_idx(); // Prefetch TMA descriptors at the very beginning - if (threadIdx.x == kNumMathThreads) { - // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(&tensor_map_b); cute::prefetch_tma_descriptor(&tensor_map_sfa); @@ -90,35 +90,26 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Data on shared memory auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); - __nv_fp8_e4m3* smem_a[kNumStages]; - __nv_fp8_e4m3* smem_b[kNumStages]; - float* smem_sfa[kNumStages]; - float* smem_sfb; - - // TMA Barrier for both divisible and non-divisible cases - Barrier* full_barriers[kNumStages]; - Barrier* empty_barriers[kNumStages]; - - // Fill shared memory pointers - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); - smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); - smem_sfa[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SFA_SIZE_PER_STAGE); - } - smem_sfb = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE)); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE); // Fill barriers auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); - #pragma unroll - for (uint32_t i = 0; i < kNumStages; ++ i) { - full_barriers[i] = barrier_start_ptr + i; - empty_barriers[i] = barrier_start_ptr + kNumStages + i; - } + auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); + auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); // Initialize barriers DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); - if (threadIdx.x == kNumMathThreads) { + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, // even with TMA multicast disabled, we want to make the behavior aligned #pragma unroll @@ -128,107 +119,72 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, } // Make initialized barrier visible in async proxy - cutlass::arch::fence_view_async_shared(); cutlass::arch::fence_barrier_init(); } // Synchronize all threads to make barrier visible in normal memory model (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); - // For pipeline unrolling - struct DivisibleK {}; - struct NotDivisibleK {}; - struct SkipComputation {}; - struct NotSkipComputation {}; - auto launch_k_iterations = [=](const auto& func, bool skip_computation, uint32_t num_former_iters) { - constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; - constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; - constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; - - // NOTES: for too-many branches (> 5), we disable this optimization - // Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value - outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) { - if (skip_computation) { - for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) - func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type); - } else if (shape_k % kFullKOfAllStages == 0) { - for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) - func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); - } else { - for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) - func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); - func(num_iterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type); - } - }, func, kShouldOptimize ? num_former_iters : 0); - }; - // Register reconfigurations constexpr uint32_t kNumTMARegisters = 40; constexpr uint32_t kNumMathRegisters = 232; // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); - if (threadIdx.x >= kNumMathThreads) { + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + if (warp_idx >= kNumMathThreads / 32) { // TMA warp-group for loading data cutlass::arch::warpgroup_reg_dealloc(); // NOTES: only one thread (or warp) will be used - if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) { + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { - launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) { - constexpr bool kHasDivisibleStages = cute::is_same_v; - constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); - // Assign TMA multicast number into A and B - // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. - const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); - const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; - DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); - // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all - // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Wait consumer release - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + // Issue TMA A + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t k_idx = k_block_idx * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a); + tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), + smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.get_global_idx(shape_k_scales, 1, k_block_idx), + num_tma_multicast_a); - // Issue TMA A - constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; - auto& full_barrier = *full_barriers[s]; - uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; - tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), - smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), - num_tma_multicast_a); - tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), - smem_sfa[s], m_block_idx * BLOCK_M, - scheduler.get_global_idx(shape_k_scales, 1, k_idx / BLOCK_K), - num_tma_multicast_a); - - // Issue TMA B - tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), - smem_b[s], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), - num_tma_multicast_b); - full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); - } - - // Wait unaligned cases - #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); - full_barriers[s]->arrive(); - } - }, false, 0); + // Issue TMA B + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); + } } // To safely deconstruct distributed shared barriers, we need another round of empty waits if constexpr (kNumTMAMulticast > 1) { - #pragma unroll - for (uint32_t s = 0; s < kNumStages; ++ s) - empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); } } } else { @@ -239,6 +195,11 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); + auto b_desc = make_smem_desc(smem_b[0], 1); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); + // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { // Decide the number of scales B to load @@ -259,7 +220,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) st_shared(smem_sfb + i, __ldg(local_sfb + i)); } - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); // Accumulation for WGMMA or CUDA promotion constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2); @@ -267,90 +228,96 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; // Empty barrier arrival - auto empty_barrier_arrive = [&](uint32_t s) { + auto empty_barrier_arrive = [&]() { if constexpr (kNumTMAMulticast == 1) { - lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void(); } else { auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); - lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void(); } }; - // Launch MMAs - launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) { - constexpr bool kSkipComputation = cute::is_same_v; - constexpr bool kHasDivisibleStages = cute::is_same_v; - constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? kNumStages : kNumLastStages); + // Skip useless computations + if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) { + // The compiler must know the dynamic variable `num_former_iters`'s real value + constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; - #pragma unroll - for (uint32_t s = 0; s < kNumInnerStages; ++ s) { - // Read B scales - float scale_b_0 = ld_shared(smem_sfb + k_iter * kNumStages + s), scale_b_1; - // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks - if constexpr (not kMustUseUniformedScaleB) - scale_b_1 = ld_shared(smem_sfb + k_iter * kNumStages + s + shape_k_scales); + // Dispatch `num_former_iters` and launch MMAs + dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) { + #pragma unroll 8 + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); - // Wait TMA arrivals - full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + // Read B scales + float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales); - // TODO: remove some useless computation for unaligned Ms - #pragma unroll - for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { - auto m_offset = local_idx * WAVE_BLOCK_M; + // Wait TMA arrivals + full_barriers[stage_idx]->wait(phase); - // Read A scales - // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results - auto scale_a_0 = ld_shared(smem_sfa[s] + r_0 + m_offset); - auto scale_a_1 = ld_shared(smem_sfa[s] + r_1 + m_offset); + // TODO: remove some useless computation for unaligned Ms + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; - // Commit WGMMA instructions - #pragma unroll - for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_arrive(); - #pragma unroll - for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { - auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); - auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); - WGMMA::wgmma(desc_a, desc_b, accum, k); - } - warpgroup_commit_batch(); - #pragma unroll - for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) - warpgroup_fence_operand(accum[i]); - warpgroup_wait<0>(); + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0 + m_offset); + auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1 + m_offset); - // Notify barrier arrival at the last warpgroup wave - if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) - empty_barrier_arrive(s); + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16; + b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16; + WGMMA::wgmma(a_desc, b_desc, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); - // Promote with scales - // NOTES: making it as predicates is very important for performance, comparing to two loops - float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; - float scale_0_1, scale_1_1; - if constexpr (not kMustUseUniformedScaleB) - scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(); - auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; - #pragma unroll - for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant - bool predicate = kMustUseUniformedScaleB or i < num_former_iters; - shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; - shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; - shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; - shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; - } + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } } - } - - // Wait unaligned cases + }); + } else { #pragma unroll - for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { - full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); - empty_barrier_arrive(s); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + empty_barrier_arrive(); } - }, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters); + } // TMA checks constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); @@ -364,7 +331,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Wait last TMA store to be finished if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) cute::tma_store_wait<0>(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); // Write back to shared memory using STSM and issue TMA stores DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); @@ -413,7 +380,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, } } cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); // Use TMA store to write back to global memory // TODO: compatible with FP32 output @@ -423,7 +390,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, - n_block_idx * BLOCK_N + in_block_n_offset, + epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + in_block_n_offset), scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); cute::tma_store_arrive(); } diff --git a/deep_gemm/testing/bench.py b/deep_gemm/testing/bench.py index 7e77866..8bba422 100644 --- a/deep_gemm/testing/bench.py +++ b/deep_gemm/testing/bench.py @@ -91,7 +91,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, # Profile suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress with suppress(): - schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None + schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) if not using_nsys else None profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() with profiler: for i in range(2): @@ -112,10 +112,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, is_tuple = isinstance(kernel_names, tuple) prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names - assert all([isinstance(name, str) for name in kernel_names]) if not with_multiple_kernels: for name in kernel_names: - assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' + assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table' # Save chrome traces if trace_path is not None: @@ -136,6 +135,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, total_time += float(time_str.replace(unit, '')) / scale * int(num_str) total_num += int(num_str) break - kernel_times.append(total_time / total_num) + kernel_times.append(total_time / total_num if total_num > 0 else 0) return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/deep_gemm/utils/math.py b/deep_gemm/utils/math.py index 46804e7..1a47e15 100644 --- a/deep_gemm/utils/math.py +++ b/deep_gemm/utils/math.py @@ -16,13 +16,16 @@ def ceil_to_ue8m0(x: torch.Tensor): def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 and x.size(1) % 128 == 0 + assert x.dim() == 2 m, n = x.shape - x_view = x.view(m, -1, 128) + padded_n = align(n, 128) + x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, 128) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) sf = x_amax / 448.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf - return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: @@ -54,4 +57,4 @@ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) - sf = x_amax / 448.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) - return x_scaled, sf.squeeze() \ No newline at end of file + return x_scaled, sf.squeeze() diff --git a/tests/generators.py b/tests/generators.py index 82cdbdc..d856217 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -1,7 +1,7 @@ import enum import random import torch -from typing import Generator, Tuple, List +from typing import Generator, List from deep_gemm.utils import ( align, ceil_div, @@ -11,7 +11,6 @@ from deep_gemm.utils import ( class KernelType(enum.Enum): - # For SM100 GEMMs Kernel1D1D = 0 Kernel1D2D = 1 KernelNoSF = 2 @@ -48,62 +47,87 @@ def get_ue8m0_usage(kernel_type: KernelType) -> bool: return kernel_type.is_1d1d() -def get_kernel_types(use_bf16: bool = False) -> tuple: - if use_bf16: +def get_kernel_types(dtype: torch.dtype) -> tuple: + if dtype == torch.bfloat16: return (KernelType.KernelNoSF, ) - return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D) + + # TODO: SM100 1D2D kernels are going to be deprecated + # But if you want to test it, please use: + # `(KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D)` + return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, ) -def get_out_dtype() -> tuple: - return (torch.bfloat16, ) if get_arch_major() == 9 else (torch.bfloat16, torch.float) +def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator: + for major_a in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor): + for major_b in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor): + if major_a.is_mn_major() and not allow_a_mn_major: + continue + if major_b.is_mn_major() and not allow_b_mn_major: + continue + yield major_a, major_b -def get_major_ab(freeze_a: bool) -> tuple: - # TODO: test other major-ness for SM90 BF16 GEMMs - if get_arch_major() == 9: - return ((MajorTypeAB.KMajor, MajorTypeAB.KMajor), ) - if freeze_a: - return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor) - return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor), \ - (MajorTypeAB.MNMajor, MajorTypeAB.KMajor), (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) +def enumerate_normal(dtype: torch.dtype) -> Generator: + assert dtype in (torch.float8_e4m3fn, torch.bfloat16) + + fp32_output_nk = [(256, 7168), (129280, 7168)] + bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)] + m_fwd_list, m_bwd_list = [128, 4096], [4096, ] + nk_list = bf16_output_nk + + # Only BF16 GEMM needs FP32 outputs + if dtype == torch.bfloat16: + nk_list += fp32_output_nk + + for kernel_type in get_kernel_types(dtype): + # Forward + for m in m_fwd_list: + for n, k in nk_list: + out_dtype = torch.float if (n, k) in fp32_output_nk else torch.bfloat16 + yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype + + # TODO: support BF16 SM90 MN-major kernels + if dtype == torch.bfloat16 and get_arch_major() == 9: + continue + + # Backward + for m in m_bwd_list: + for n, k in nk_list: + override_major = MajorTypeAB.MNMajor + override_kernel_type = kernel_type + if get_arch_major() == 9 and dtype == torch.float8_e4m3fn: + override_major = MajorTypeAB.KMajor + override_kernel_type = KernelType.Kernel1D1D + yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad + yield override_kernel_type, n, m, k, override_major, override_major, True, torch.float # Wgrad + yield override_kernel_type, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad -def enumerate_normal(use_bf16: bool = False) -> Generator: - for kernel_type in get_kernel_types(use_bf16): - for m in (128, 4096): - for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]: - for major_a, major_b in get_major_ab(False): - for out_dtype in get_out_dtype(): - for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True): - yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype - - -def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator: - for kernel_type in get_kernel_types(use_bf16): +def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator: + for kernel_type in get_kernel_types(dtype): for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)): - for major_a, major_b in get_major_ab(True): + for major_a, major_b in get_major_ab(False, get_arch_major() > 9): yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b -def enumerate_m_grouped_masked() -> Generator: +def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator: max_m = 4096 - for kernel_type in get_kernel_types(): + for kernel_type in get_kernel_types(dtype): for num_groups, m in ((1, 1024), (2, 512), (4, 256)): for n, k in ((4096, 7168), (7168, 2048), ): yield kernel_type, num_groups, max_m, m, n, k def enumerate_k_grouped_contiguous(): - # TODO: support SM90 kernels - if get_arch_major() == 9: - return [] - + # Only K-major is supported for SM90 + major_a, major_b = (MajorTypeAB.KMajor, MajorTypeAB.KMajor) if get_arch_major() == 9 \ + else (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) # Must with FP32 accumulation and 1D1D kernels for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64 ( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32 (16, 4096, 7168, 2048), (16, 7168, 2048, 2048)): # EP16 ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)] - yield num_groups, m, n, ks, expected_k_per_group + yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group def enumerate_sf_layout(): @@ -134,6 +158,7 @@ def enumerate_transpose(): def generate_normal(m: int, n: int, k: int, major_a: MajorTypeAB, major_b: MajorTypeAB, accumulate: bool, out_dtype: torch.dtype, + kernel_type: KernelType, use_ue8m0: bool = False, use_bf16: bool = False): a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) @@ -147,7 +172,9 @@ def generate_normal(m: int, n: int, k: int, b = b if major_b.is_k_major() else b.T.contiguous().T return a, b, c, d, ref_d - a_fp8, b_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0), per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0) + a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0) + b_fp8 = per_token_cast_to_fp8(b, use_ue8m0=use_ue8m0) if kernel_type.is_1d1d() and accumulate \ + else per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0) a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1]) b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1]) return a_fp8, b_fp8, c, d, ref_d @@ -214,7 +241,7 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: return a_fp8, b_fp8, masked_m, d, ref_d -def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int], use_ue8m0: bool): +def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], use_ue8m0: bool): assert get_mk_alignment_for_contiguous_layout() % 128 == 0 k = sum(ks) @@ -232,4 +259,20 @@ def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int] a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0) b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0) + + # Transpose for K Major A/B + if (major_a, major_b) == (MajorTypeAB.KMajor, MajorTypeAB.KMajor): + a, sfa = a_fp8 + b, sfb = b_fp8 + new_a = torch.empty((sum(ks) * m, ), dtype=a.dtype, device=a.device) + new_b = torch.empty((sum(ks) * n, ), dtype=b.dtype, device=b.device) + prefix = 0 + for K in ks: + new_a[prefix * m : (prefix + K) * m] = a[prefix : prefix + K, ].T.flatten() + new_b[prefix * n : (prefix + K) * n] = b[prefix : prefix + K, ].T.flatten() + prefix += K + a_fp8, b_fp8 = (new_a, sfa.T), (new_b, sfb.T) + else: + assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) + return k, a_fp8, b_fp8, c, d, ref_d diff --git a/tests/test_attention.py b/tests/test_attention.py new file mode 100644 index 0000000..98ffa20 --- /dev/null +++ b/tests/test_attention.py @@ -0,0 +1,64 @@ +import random +import torch +from typing import Tuple + +import deep_gemm +from deep_gemm.testing import bench_kineto, calc_diff, count_bytes +from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8 + +from generators import get_arch_major, generate_normal, get_ue8m0_usage, get_kernel_types, MajorTypeAB + + +def apply_skip_head_mid(d: torch.Tensor, head_splits: Tuple[int, int, int]): + left, mid, right = head_splits + m, n = d.shape + assert n % (left + right) == 0 + num_heads = n // (left + right) + + # Split and insert padding tensor + d = d.view(m, num_heads, -1) + d_left = d[:, :, :left] + d_right = d[:, :, -right:] + + d_mid = torch.zeros((m, num_heads, mid), dtype=d.dtype, device=d.device) + return torch.cat([d_left, d_mid, d_right], dim=2).view(m, -1) + + +def test_gemm_skip_head_mid() -> None: + print('Testing GEMM skip head mid:') + head_splits = (128, 64, 128) + + major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor + out_dtype, accumulate = torch.bfloat16, False + + for kernel_type in get_kernel_types(dtype=torch.float8_e4m3fn): + for m in (128, 4096): + for n, k in [(32768, 512), (8192, 512)]: + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + + a, b, _, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0) + d = apply_skip_head_mid(d, head_splits) + ref_d = apply_skip_head_mid(ref_d, head_splits) + + deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast) + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {kernel_opt}, {diff:.5f}' + + t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast), + 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d)) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.manual_seed(0) + random.seed(0) + + test_gemm_skip_head_mid() diff --git a/tests/test_bf16.py b/tests/test_bf16.py index 790f700..34e0b42 100644 --- a/tests/test_bf16.py +++ b/tests/test_bf16.py @@ -7,6 +7,7 @@ from deep_gemm.testing import ( calc_diff, count_bytes ) from generators import ( + get_arch_major, enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked ) @@ -14,14 +15,18 @@ from generators import ( def test_gemm() -> None: print('Testing GEMM:') - for _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(use_bf16=True): + for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16): + # TODO: support accumulation for SM90 BF16 GEMM + if get_arch_major() == 9 and accumulate: + continue + major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' out_opt = 'FP32' if out_dtype == torch.float else 'BF16' acc_opt = f'acc={int(accumulate)}' for test_alias in (False, True): - a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True) + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) func_name = f'bf16_gemm_{major_opt.lower() if test_alias else "nt"}' if test_alias: a = a if major_a.is_k_major() else a.T @@ -31,28 +36,22 @@ def test_gemm() -> None: diff = calc_diff(d, ref_d) assert diff < 0.0001, (f'{m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=}, ' f'{diff:.5f}, alias={test_alias}') - a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True) + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) - cublas_t = 0 t = bench_kineto(lambda: deep_gemm.bf16_gemm_nt(a, b, d, c=c), 'bf16_gemm', suppress_kineto_output=True) - if accumulate == 0 and out_dtype == torch.bfloat16: - # noinspection PyBroadException - try: - cublas_t = bench_kineto(lambda: a @ b.T, 'nvjet', suppress_kineto_output=True) - except Exception: - pass - print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, {out_opt}, {acc_opt}): ' - f'{t * 1e6:4.0f} us | ' + cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:5.0f} us | ' f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' - f'{cublas_t / t:.2f}x cuBLAS') + f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS') print() def test_m_grouped_gemm_contiguous() -> None: print('Testing m-grouped contiguous GEMM:') - for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(use_bf16=True): + for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(torch.bfloat16): major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' @@ -85,7 +84,7 @@ def test_m_grouped_gemm_masked() -> None: print('Testing m-grouped masked GEMM:') # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. - for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): + for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.bfloat16): # Test correctness for i in range(10): a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True) @@ -111,6 +110,27 @@ def test_m_grouped_gemm_masked() -> None: print() +def test_cublaslt_gemm() -> None: + print('Testing cuBLASLt GEMM:') + for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + out_opt = 'FP32' if out_dtype == torch.float else 'BF16' + acc_opt = f'acc={int(accumulate)}' + + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) + deep_gemm.cublaslt_gemm_nt(a, b, d, c=c) + diff = calc_diff(d, ref_d) + assert diff < 5e-7, f'{diff=}, ({m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=})' + + t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), 'nvjet', suppress_kineto_output=True,) + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:5.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s') + print() + + if __name__ == '__main__': torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True @@ -121,5 +141,9 @@ if __name__ == '__main__': print(f' > {deep_gemm.__path__}\n') test_gemm() - test_m_grouped_gemm_contiguous() - test_m_grouped_gemm_masked() + # TODO: support SM100 + if get_arch_major() == 9: + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked() + + test_cublaslt_gemm() diff --git a/tests/test_einsum.py b/tests/test_einsum.py new file mode 100644 index 0000000..cfdd453 --- /dev/null +++ b/tests/test_einsum.py @@ -0,0 +1,85 @@ +import random +import torch + +import deep_gemm +from deep_gemm.testing import ( + bench, bench_kineto, + calc_diff, count_bytes +) + + +def test_bmk_bnk_mn() -> None: + print('Testing "bmk, bnk -> mn":') + for s in (129, 4096, 8192): + for m, n, k in [(128, 384, 128), (256, 256, 256), (384, 128, 384)]: + for dtype in (torch.float, torch.bfloat16): + a = torch.randn((s, m, k), dtype=torch.bfloat16, device='cuda') + b = torch.randn((s, n, k), dtype=torch.bfloat16, device='cuda') + d = torch.randn((m, n), dtype=dtype, device='cuda') + c = d if dtype == torch.float else None + + # Test correctness + ref_d = (c if dtype == torch.float else 0) + torch.bmm(a.float(), b.float().mT).sum(0) + deep_gemm.einsum('bmk,bnk->mn', a, b, d, c=c) + assert calc_diff(d, ref_d) < 1e-5 + + t = bench_kineto(lambda: deep_gemm.einsum('bmk,bnk->mn', a, b, d, c=c), 'bmn_bnk_mn_gemm_impl', suppress_kineto_output=True) + print(f' > Perf (b={s:4.0f}, {m=}, {n=}, {k=}, {"FP32" if dtype == torch.float else "BF16"}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * s * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b) + (d.numel() * 4)) / 1e9 / t:4.0f} GB/s') + print() + + +def test_bhr_hdr_bhd(): + print('Testing "bhr, hdr -> bhd":') + for b in (128, 4096, 8192): + for h, r, d in [(128, 512, 128)]: + x = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16) + fy = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16) + y = fy[:, :, :r] + ref_z = torch.einsum('bhr,hdr->bhd', x, y) + z = torch.empty((b, h, d), device='cuda', dtype=torch.bfloat16) + deep_gemm.einsum('bhr,hdr->bhd', x, y, z) + assert calc_diff(z, ref_z) < 1e-10 + + t = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z), 'nvjet', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:.0f} TFLOPS | ' + f'{count_bytes((x, y, z)) / t / 1e9:.0f} GB/s') + print() + + +def test_bhd_hdr_bhr(): + print('Testing "bhd, hdr -> bhr":') + for b in (128, 4096, 8192): + for h, r, d in [(128, 512, 128)]: + x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16) + fy = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16) + y = fy[:, :, :r] + ref_z = torch.einsum('bhd,hdr->bhr', x, y) + z = torch.empty((b, h, r), device='cuda', dtype=torch.bfloat16) + deep_gemm.einsum('bhd,hdr->bhr', x, y, z) + assert calc_diff(z, ref_z) < 1e-10 + + t = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z), 'nvjet', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:.0f} TFLOPS | ' + f'{count_bytes((x, y, z)) / t / 1e9:.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_bmk_bnk_mn() + test_bhr_hdr_bhd() + test_bhd_hdr_bhr() diff --git a/tests/test_fp8.py b/tests/test_fp8.py index 0c7d3ce..7415e07 100644 --- a/tests/test_fp8.py +++ b/tests/test_fp8.py @@ -10,7 +10,7 @@ from deep_gemm.testing import ( ) from generators import ( - KernelType, get_ue8m0_usage, + KernelType, get_arch_major, get_ue8m0_usage, enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous ) @@ -18,7 +18,7 @@ from generators import ( def test_gemm() -> None: print('Testing GEMM:') - for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(): + for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn): major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' out_opt = 'FP32' if out_dtype == torch.float else 'BF16' @@ -26,42 +26,35 @@ def test_gemm() -> None: kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' use_ue8m0 = get_ue8m0_usage(kernel_type) disable_ue8m0_cast = not use_ue8m0 + recipe = (1, 1, 128) if kernel_type.is_1d1d() and accumulate else None for test_alias in (False, True): - a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0) + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0) func_name = f'fp8_gemm_{major_opt.lower() if test_alias else "nt"}' if test_alias: a = a if major_a.is_k_major() else (a[0].T, a[1].T) b = b if major_b.is_k_major() else (b[0].T, b[1].T) assert a[0].is_contiguous() and b[0].is_contiguous() - getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast) + getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe) diff = calc_diff(d, ref_d) assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, ' f'{diff:.5f}, alias={test_alias}') - a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0) - # Test launch overhead - launch_start_t = time.time_ns() - deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast) - launch_end_t = time.time_ns() - torch.cuda.synchronize() - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast) - - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): ' - f'launch {(launch_end_t - launch_start_t) / 1e3:4.0f} us | {t * 1e6:4.0f} us | ' - f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s') + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0) + t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe), + 'fp8_gemm', suppress_kineto_output=True) + cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:4.0f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' + f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS') print() def test_m_grouped_gemm_contiguous() -> None: print('Testing m-grouped contiguous GEMM:') - for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(): + for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn): major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' @@ -86,7 +79,7 @@ def test_m_grouped_gemm_contiguous() -> None: deep_gemm.m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast) t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}): ' + print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}): ' f'{t * 1e6:4.0f} us | ' f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') @@ -97,7 +90,7 @@ def test_m_grouped_gemm_masked() -> None: print('Testing m-grouped masked GEMM:') # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. - for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): + for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.float8_e4m3fn): kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' use_ue8m0 = get_ue8m0_usage(kernel_type) disable_ue8m0_cast = not use_ue8m0 @@ -130,26 +123,31 @@ def test_m_grouped_gemm_masked() -> None: def test_k_grouped_gemm_contiguous() -> None: print('Testing k-grouped contiguous GEMM:') - for num_groups, m, n, ks, expected_k_per_group in enumerate_k_grouped_contiguous(): + k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \ + else deep_gemm.k_grouped_fp8_gemm_tn_contiguous + for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(): use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D) for test_empty_groups in (False, True): new_ks = copy.deepcopy(ks) - if test_empty_groups: + if test_empty_groups and len(ks) > 1: new_ks[random.randint(0, num_groups - 1)] = 0 - k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, new_ks, use_ue8m0=use_ue8m0) + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0) new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda') - deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, new_ks, new_ks_tensor, c=c) - diff = calc_diff(d, ref_d) - assert diff < 0.001, f'{m=}, {n=}, {k=}, {i=}, {diff:.5f}' + k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c) + + do_check = True + if do_check: + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {ks=}, {diff:.5f}' # Test performance - k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, ks, use_ue8m0=use_ue8m0) + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0) ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') # noinspection PyShadowingNames def test_func(): - deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=c) + k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c) t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): '