From f85ec649d76846552cfd637b6c99fb9c985fd9eb Mon Sep 17 00:00:00 2001 From: Ray Wang Date: Fri, 15 Aug 2025 18:32:35 +0800 Subject: [PATCH] Make various updates and fixes: (#164) - Add BF16 support for SM90 and SM100 - Refactor Python APIs - Other fixes and code refactoring --- README.md | 2 +- csrc/apis/gemm.hpp | 471 +++++++++++++++++ csrc/apis/layout.hpp | 85 +++ csrc/apis/runtime.hpp | 28 + csrc/jit/compiler.hpp | 6 +- csrc/jit/handle.hpp | 2 +- csrc/jit/kernel_runtime.hpp | 4 +- csrc/jit_kernels/heuristics/common.hpp | 5 +- csrc/jit_kernels/heuristics/sm100.hpp | 4 +- csrc/jit_kernels/heuristics/sm90.hpp | 10 +- csrc/jit_kernels/impls/sm100_bf16_gemm.hpp | 143 +++++ .../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp | 5 +- .../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp | 5 +- csrc/jit_kernels/impls/sm90_bf16_gemm.hpp | 229 ++++++++ csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 5 +- csrc/jit_kernels/impls/smxx_layout.hpp | 63 ++- csrc/python_api.cpp | 405 +------------- csrc/utils/exception.hpp | 13 +- deep_gemm/__init__.py | 48 +- .../include/deep_gemm/common/scheduler.cuh | 22 +- .../include/deep_gemm/common/sm90_utils.cuh | 76 +++ deep_gemm/include/deep_gemm/common/utils.cuh | 18 + .../deep_gemm/impls/sm100_bf16_gemm.cuh | 497 +++++++++++++++++- .../deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh | 6 +- .../deep_gemm/impls/sm90_bf16_gemm.cuh | 342 +++++++++++- .../deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 2 +- .../include/deep_gemm/impls/smxx_layout.cuh | 39 ++ pyproject.toml | 3 - setup.py | 4 + tests/generators.py | 56 +- tests/test_bf16.py | 125 +++++ tests/test_fp8.py | 4 +- tests/test_layout.py | 46 +- tests/test_lazy_init.py | 15 + 34 files changed, 2293 insertions(+), 495 deletions(-) create mode 100644 csrc/apis/gemm.hpp create mode 100644 csrc/apis/layout.hpp create mode 100644 csrc/apis/runtime.hpp create mode 100644 csrc/jit_kernels/impls/sm100_bf16_gemm.hpp create mode 100644 csrc/jit_kernels/impls/sm90_bf16_gemm.hpp delete mode 100644 pyproject.toml create mode 100644 tests/test_bf16.py create mode 100644 tests/test_lazy_init.py diff --git a/README.md b/README.md index 491574c..c2e23f9 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,7 @@ We also provide a K-axis-grouped API for MoE weight backward (with M and N must During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions. -Use `fp8_m_grouped_gemm_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input. +Use `m_grouped_fp8_gemm_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input. #### Utilities diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp new file mode 100644 index 0000000..a6bd344 --- /dev/null +++ b/csrc/apis/gemm.hpp @@ -0,0 +1,471 @@ +#pragma once + +#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" +#include "../jit_kernels/impls/sm90_bf16_gemm.hpp" +#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" +#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp" +#include "../jit_kernels/impls/sm100_bf16_gemm.hpp" + +#include "layout.hpp" + +namespace deep_gemm::gemm { + +static void fp8_gemm_nt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [N, K].T` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + if (fp8_requires_k_major()) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + } + + // C/D must be N-major + check_major_type_cd(d); + + // Type and shape checks + const auto& [m , k ] = get_shape<2>(a.first); + const auto& [n , k_] = get_shape<2>(b.first); + const auto& [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Check C as well + if (c.has_value()) { + check_major_type_cd(c.value()); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); + } + + // Do nothing if the problem is empty + if (m == 0) + return; + + // Transform SFA and SFB into compute-required layout + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast); + const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast); + + // Dispatch into different implements + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { + sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void fp8_gemm_nn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, compiled_dims, disable_ue8m0_cast); +} + +static void fp8_gemm_tn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, + {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, compiled_dims, disable_ue8m0_cast); +} + +static void fp8_gemm_tt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b, + d, c, recipe, compiled_dims, disable_ue8m0_cast); +} + +static void m_grouped_fp8_gemm_nt_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& m_indices, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + if (fp8_requires_k_major()) + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(m_indices.is_contiguous()); + + // Type and shape checks + const auto& [m, k] = get_shape<2>(a.first); + const auto& [num_groups, n, k_] = get_shape<3>(b.first); + const auto& [m_, n_] = get_shape<2>(d); + const auto& m__ = static_cast(m_indices.numel()); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Do nothing if empty + if (m == 0) + return; + + // Transform SFA and SFB into compute-required layout + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast); + const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { + sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void m_grouped_fp8_gemm_nn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const std::optional>& recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + m_grouped_fp8_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)}, + d, m_indices, recipe, compiled_dims, disable_ue8m0_cast); +} + +static void m_grouped_fp8_gemm_nt_masked(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& expected_m, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[G, M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(masked_m.is_contiguous()); + + // Type and shape checks + const auto& [num_groups, m, k] = get_shape<3>(a.first); + const auto& [num_groups_, n, k_] = get_shape<3>(b.first); + const auto& [num_groups__, m_, n_] = get_shape<3>(d); + const auto& num_groups___ = static_cast(masked_m.numel()); + DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Transform scaling factors + if (not recipe.has_value()) + recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); + const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast); + const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { + sm100_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void k_grouped_fp8_gemm_tn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::optional& c, + const std::tuple& recipe, + const std::string& compiled_dims) { + // Must be 1D1D kernel + DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); + + // Contiguity checks + DG_HOST_ASSERT(a.first.is_contiguous()); + DG_HOST_ASSERT(b.first.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + if (c.has_value()) { + DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); + DG_HOST_ASSERT(c.value().is_contiguous()); + } + + // Do nothing if empty + if (std::accumulate(ks.begin(), ks.end(), 0) == 0) + return; + + // Transform SF with padding + const auto& [_, m] = get_shape<2>(a.first); + const auto& [__, n] = get_shape<2>(b.first); + const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe); + const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 10) { + fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void bf16_gemm_nt(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + // Shape must be `[M, K] @ [N, K].T` + const auto& major_a = get_major_type_ab(a); + const auto& major_b = get_major_type_ab(b); + + // C/D must be N-major + check_major_type_cd(d); + + // Type and shape checks + const auto& [m , k ] = get_shape<2>(a); + const auto& [n , k_] = get_shape<2>(b); + const auto& [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Check C as well + if (c.has_value()) { + check_major_type_cd(c.value()); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); + } + + // Do nothing if the problem is empty + if (m == 0) + return; + + // Dispatch into different implements + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10) { + sm100_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void bf16_gemm_nn(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + bf16_gemm_nt(a, b.transpose(0, 1), d, c, compiled_dims); +} + +static void bf16_gemm_tn(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + bf16_gemm_nt(a.transpose(0, 1), b.transpose(0, 1), d, c, compiled_dims); +} + +static void bf16_gemm_tt(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + bf16_gemm_nt(a.transpose(0, 1), b, d, c, compiled_dims); +} + +static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const torch::Tensor& m_indices, + const std::string& compiled_dims) { + // Shape must be `[M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a); + const auto& major_b = get_major_type_ab(b); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(m_indices.is_contiguous()); + + // Type and shape checks + const auto& [m, k] = get_shape<2>(a); + const auto& [num_groups, n, k_] = get_shape<3>(b); + const auto& [m_, n_] = get_shape<2>(d); + const auto& m__ = static_cast(m_indices.numel()); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Do nothing if empty + if (m == 0) + return; + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_m_grouped_bf16_gemm_contiguous(a, b, d, m_indices, + num_groups, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const torch::Tensor& masked_m, + const int& expected_m, const std::string& compiled_dims) { + // Shape must be `[G, M, K] @ [G, N, K].mT` + const auto& major_a = get_major_type_ab(a); + const auto& major_b = get_major_type_ab(b); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(masked_m.is_contiguous()); + + // Type and shape checks + const auto& [num_groups, m, k] = get_shape<3>(a); + const auto& [num_groups_, n, k_] = get_shape<3>(b); + const auto& [num_groups__, m_, n_] = get_shape<3>(d); + const auto& num_groups___ = static_cast(masked_m.numel()); + DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bf16_m_grouped_gemm_masked(a, b, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void register_apis(pybind11::module_& m) { + // FP8 GEMMs + m.def("fp8_gemm_nt", &fp8_gemm_nt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_gemm_nn", &fp8_gemm_nn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_gemm_tn", &fp8_gemm_tn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "mn", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_gemm_tt", &fp8_gemm_tt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "mn", + py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), + py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), + py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), + py::arg("expected_m"), py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); + m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), + py::arg("ks_tensor"), py::arg("c") = std::nullopt, + py::arg("recipe") = std::make_tuple(1, 1, 128), + py::arg("compiled_dims") = "mn"); + + // BF16 GEMMs + m.def("bf16_gemm_nt", &bf16_gemm_nt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "nk"); + m.def("bf16_gemm_nn", &bf16_gemm_nn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "nk"); + m.def("bf16_gemm_tn", &bf16_gemm_tn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "mn"); + m.def("bf16_gemm_tt", &bf16_gemm_tt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "mn"); + m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), + py::arg("compiled_dims") = "nk"); + m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), + py::arg("expected_m"), py::arg("compiled_dims") = "nk"); +} + +} // namespace deep_gemm::gemm diff --git a/csrc/apis/layout.hpp b/csrc/apis/layout.hpp new file mode 100644 index 0000000..27c4120 --- /dev/null +++ b/csrc/apis/layout.hpp @@ -0,0 +1,85 @@ +#pragma once + +#include "../utils/layout.hpp" +#include "../jit_kernels/impls/smxx_layout.hpp" + +namespace deep_gemm::layout { + +static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, + const int& mn, const int& k, + const std::tuple& recipe, + const std::optional& num_groups, + const bool& is_sfa, + const bool& disable_ue8m0_cast) { + const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe); + const auto& gran_k = std::get<2>(recipe); + const auto& arch_major = device_runtime->get_arch_major(); + + // Pre-transform checks + check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups); + + // (FP32, 1, 128) on SM90: transform to TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) + return get_mn_major_tma_aligned_tensor(sf); + + // (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and arch_major == 10) { + DG_HOST_ASSERT(not disable_ue8m0_cast); + return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf); + } + + // (FP32, 128, 128) on SM90: no need to transform, check shape and contiguous + if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) + return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat); + + // (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and arch_major == 10) { + DG_HOST_ASSERT(not disable_ue8m0_cast); + const auto& broadcasted = sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128)); + return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted); + } + + // (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major + if (sf.scalar_type() == torch::kInt and gran_mn == 1 and gran_k == 128 and arch_major == 10) + return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt); + + DG_HOST_UNREACHABLE("Unknown SF transformation"); +} + +static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::tuple& recipe) { + DG_HOST_ASSERT(sf.dim() == 2); + DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); + const auto& arch_major = device_runtime->get_arch_major(); + + // FP32 on SM90 + if (sf.scalar_type() == torch::kFloat and arch_major == 9) + DG_HOST_UNREACHABLE("Unimplemented"); + + // FP32 on SM100 + if (sf.scalar_type() == torch::kFloat and arch_major == 10) + return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks); + + // INT on SM100 + if (sf.scalar_type() == torch::kFloat and arch_major == 10) + DG_HOST_UNREACHABLE("Unimplemented"); + + DG_HOST_UNREACHABLE("Unknown cases"); +} + +static void register_apis(pybind11::module_& m) { + m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout, + py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"), + py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false, + py::arg("disable_ue8m0_cast") = false); + + m.def("get_tma_aligned_size", &get_tma_aligned_size); + m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout); + m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor); + m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor); + m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor); +} + +} // namespace deep_gemm::layout diff --git a/csrc/apis/runtime.hpp b/csrc/apis/runtime.hpp new file mode 100644 index 0000000..9ef4207 --- /dev/null +++ b/csrc/apis/runtime.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "../jit/compiler.hpp" +#include "../jit/device_runtime.hpp" + +namespace deep_gemm::runtime { + +static void register_apis(pybind11::module_& m) { + m.def("set_num_sms", [&](const int& new_num_sms) { + device_runtime->set_num_sms(new_num_sms); + }); + m.def("get_num_sms", [&]() { + return device_runtime->get_num_sms(); + }); + m.def("set_tc_util", [&](const int& new_tc_util) { + device_runtime->set_tc_util(new_tc_util); + }); + m.def("get_tc_util", [&]() { + return device_runtime->get_tc_util(); + }); + + m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) { + Compiler::prepare_init(library_root_path, cuda_home_path_by_python); + KernelRuntime::prepare_init(cuda_home_path_by_python); + }); +} + +} // namespace deep_gemm::runtime diff --git a/csrc/jit/compiler.hpp b/csrc/jit/compiler.hpp index 0e84b48..46e92b6 100644 --- a/csrc/jit/compiler.hpp +++ b/csrc/jit/compiler.hpp @@ -35,10 +35,10 @@ public: } static void prepare_init(const std::string& library_root_path, - const std::string& cuda_home_path_by_torch) { + const std::string& cuda_home_path_by_python) { Compiler::library_root_path = library_root_path; Compiler::library_include_path = Compiler::library_root_path / "include"; - Compiler::cuda_home = cuda_home_path_by_torch; + Compiler::cuda_home = cuda_home_path_by_python; Compiler::library_version = get_library_version(); } @@ -64,6 +64,8 @@ public: get_env("DG_JIT_CPP_STANDARD", 20)); if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0)) flags += " --ptxas-options=--verbose"; + if (get_env("DG_JIT_WITH_LINEINFO", 0)) + flags += " -Xcompiler -rdynamic -lineinfo"; } virtual ~Compiler() = default; diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index 754b299..1875d54 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -8,7 +8,7 @@ namespace deep_gemm { -#if CUDART_VERSION >= 12080 or defined(DG_JIT_USE_DRIVER_API) +#if CUDART_VERSION >= 12080 and not defined(DG_JIT_USE_DRIVER_API) // Use CUDA runtime API using LibraryHandle = cudaLibrary_t; diff --git a/csrc/jit/kernel_runtime.hpp b/csrc/jit/kernel_runtime.hpp index 5a6022c..42b7b4c 100644 --- a/csrc/jit/kernel_runtime.hpp +++ b/csrc/jit/kernel_runtime.hpp @@ -64,8 +64,8 @@ public: kernel = load_kernel(cubin_path, symbol_names[0], &library); } - static void prepare_init(const std::string& cuda_home_path_by_torch) { - cuda_home = cuda_home_path_by_torch; + static void prepare_init(const std::string& cuda_home_path_by_python) { + cuda_home = cuda_home_path_by_python; } static bool check_validity(const std::filesystem::path& dir_path) { diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index a7371b7..3ed4d2a 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -1,6 +1,7 @@ #pragma once #include "../../utils/math.hpp" +#include "../../utils/layout.hpp" namespace deep_gemm { @@ -146,7 +147,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, const bool& with_accumulation, const int& num_sms) { - DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16); DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); // Select M/N block sizes @@ -179,7 +180,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k for (const auto& block_n: block_ns) { const int& num_waves = get_num_waves(block_m, block_n); const auto& last_util = get_last_wave_util(block_m, block_n); - if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n)) + if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n, block_k)) continue; bool success = false; diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp index e26b69f..4e58289 100644 --- a/csrc/jit_kernels/heuristics/sm100.hpp +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -43,7 +43,7 @@ struct SM100ArchSpec { static bool is_block_size_legal(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, - const int& block_m, const int& block_n) { + const int& block_m, const int& block_n, const int& block_k) { // TODO: consider more carefully for BF16 GEMMs // 2SM BF16 UMMA does not support `N % 32 != 0` if (ab_dtype == torch::kBFloat16 and block_n % 32 != 0) @@ -106,7 +106,7 @@ struct SM100ArchSpec { const int& swizzle_cd_mode, const at::ScalarType& cd_dtype) { constexpr static int layout_ad_m = 128; - return (kernel_type == KernelType::Kernel1D1D ? std::min(block_m, layout_ad_m) : block_m) * swizzle_cd_mode * 2; + return (kernel_type != KernelType::Kernel1D2D ? std::min(block_m, layout_ad_m) : block_m) * swizzle_cd_mode * 2; } static std::pair get_sf_smem_size_per_stage(const KernelType& kernel_type, diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp index a1cb5b4..16ca018 100644 --- a/csrc/jit_kernels/heuristics/sm90.hpp +++ b/csrc/jit_kernels/heuristics/sm90.hpp @@ -30,15 +30,19 @@ struct SM90ArchSpec { static bool is_block_size_legal(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, - const int& block_m, const int& block_n) { + const int& block_m, const int& block_n, const int& block_k) { // FP32 output does not support `block_m == 256` if (cd_dtype == at::kFloat and block_m == 256) return false; + // TODO: more general block N selection // Must be some fixed block N selections - if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 or block_n != 152)) + if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 and block_n != 152)) return false; - if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 or block_n != 160)) + + // Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k` + // Or too many register spills + if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192)) return false; // Avoid bank conflicts for FP32 output diff --git a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp new file mode 100644 index 0000000..033a7b7 --- /dev/null +++ b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp @@ -0,0 +1,143 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100BF16GemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void* grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_c; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_bf16_gemm_impl< + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, + {}, {}, + {}, {}, + {}, + {}, {}, {}, + {} + >); +}}; +)", + to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.num_groups, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype), + args.gemm_config.tc_util); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_c, args.tensor_map_d)); + } +}; + +static void sm100_bf16_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + // TODO: test other Ks + DG_HOST_ASSERT(k % 64 == 0); + const auto& config = get_best_config( + GemmType::Normal, KernelType::KernelNoSF, + m, n, k, 1, major_a, major_b, + torch::kBFloat16, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + const auto& cd = c.value_or(d); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_c = make_tma_cd_desc(cd, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(cd.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + + // Duplicate the accumulator if necessary + if (c.has_value()) { + if (c->data_ptr() == d.data_ptr()) { + DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); + } else { + // ReSharper disable once CppExpressionWithoutSideEffects + d.copy_(c.value()); + } + } + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_c = tensor_map_c, + .tensor_map_d = tensor_map_d + }; + const auto& code = SM100BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_bf16_gemm", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index d4e573b..67272d9 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -3,6 +3,7 @@ #include #include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" #include "../../jit/kernel_runtime.hpp" #include "../../utils/exception.hpp" #include "../../utils/format.hpp" @@ -155,7 +156,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con const auto& aligned_k = align(k, 128); const auto& config = get_best_config( GemmType::MGroupedContiguous, KernelType::Kernel1D1D, - m, n, k, num_groups, major_a, major_b, + m, n, k, 1, major_a, major_b, torch::kFloat8_e4m3fn, d.scalar_type(), false, device_runtime->get_num_sms()); @@ -202,7 +203,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con SM100FP8Gemm1D1DRuntime::launch(runtime, args); } -static void sm100_fp8_m_grouped_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, +static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& b, const torch::Tensor& sfb, const torch::Tensor& d, const torch::Tensor& masked_m, diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp index c33a450..727d1b7 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp @@ -3,6 +3,7 @@ #include #include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" #include "../../jit/kernel_runtime.hpp" #include "../../utils/exception.hpp" #include "../../utils/format.hpp" @@ -136,7 +137,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con const auto& aligned_k = align(k, 128); const auto& config = get_best_config( GemmType::MGroupedContiguous, KernelType::Kernel1D2D, - m, n, k, num_groups, major_a, major_b, + m, n, k, 1, major_a, major_b, torch::kFloat8_e4m3fn, d.scalar_type(), false, device_runtime->get_num_sms()); @@ -179,7 +180,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con SM100FP8Gemm1D2DRuntime::launch(runtime, args); } -static void sm100_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, +static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& b, const torch::Tensor& sfb, const torch::Tensor& d, const torch::Tensor& masked_m, diff --git a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp new file mode 100644 index 0000000..ea29883 --- /dev/null +++ b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -0,0 +1,229 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90BF16GemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + const std::string& compiled_dims; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_bf16_gemm_impl< + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, {}, + {}, {} + >); +}}; +)", + // TODO: add CD dtype + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.num_groups, + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, args.gemm_config.num_last_stages, + args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, + args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_d)); + } +}; + +static void sm90_bf16_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(k % 64 == 0); + + const auto& config = get_best_config( + GemmType::Normal, KernelType::KernelNoSF, + m, n, k, 1, major_a, major_b, + torch::kBFloat16, d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bf16_gemm", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(k % 64 == 0); + + const auto& config = get_best_config( + GemmType::MGroupedContiguous, KernelType::KernelNoSF, + m, n, k, 1, major_a, major_b, + torch::kBFloat16, d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(k % 64 == 0); + + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::KernelNoSF, + expected_m, n, k, num_groups, major_a, major_b, + torch::kBFloat16, d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_d = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .compiled_dims = compiled_dims, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + }; + const auto& code = SM90BF16GemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index 088bf1a..3afc2d3 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -3,6 +3,7 @@ #include #include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" #include "../../jit/kernel_runtime.hpp" #include "../../utils/exception.hpp" #include "../../utils/format.hpp" @@ -139,7 +140,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons const auto& aligned_k = align(k, 128); const auto& config = get_best_config( GemmType::MGroupedContiguous, KernelType::Kernel1D2D, - m, n, k, num_groups, major_a, major_b, + m, n, k, 1, major_a, major_b, torch::kFloat8_e4m3fn, d.scalar_type(), false, device_runtime->get_num_sms()); @@ -185,7 +186,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons SM90FP8Gemm1D2DRuntime::launch(runtime, args); } -static void sm90_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, +static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& b, const torch::Tensor& sfb, const torch::Tensor& d, const torch::Tensor& masked_m, diff --git a/csrc/jit_kernels/impls/smxx_layout.hpp b/csrc/jit_kernels/impls/smxx_layout.hpp index 62cae90..d8a60de 100644 --- a/csrc/jit_kernels/impls/smxx_layout.hpp +++ b/csrc/jit_kernels/impls/smxx_layout.hpp @@ -10,6 +10,35 @@ namespace deep_gemm { +class TransposeFP32Runtime final: public LaunchRuntime { +public: + struct Args { + int mn, sf_k; + int block_mn; + void *sf, *out; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&transpose_fp32< + {}, {}, {} + >); +}}; +)", args.launch_args.num_threads, args.block_mn, args.sf_k); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast(args.mn))); + } +}; + class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime { public: struct Args { @@ -88,10 +117,32 @@ static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) { if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn) return (dim == 2) ? batched_sf.squeeze(0) : batched_sf; - // Normal layout requires transposing - auto aligned_sf = torch::empty_strided({num_groups, tma_aligned_mn, sf_k}, {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, batched_sf.options()); - aligned_sf = aligned_sf.slice(1, 0, mn).copy_(batched_sf); - return (dim == 2) ? aligned_sf.squeeze(0) : aligned_sf; + const auto& out = torch::empty_strided({num_groups, mn, sf_k}, + {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, + batched_sf.options()); + + if (not batched_sf.is_contiguous()) { + // Fallback to PyTorch's slow copy if not contiguous + // ReSharper disable once CppExpressionWithoutSideEffects + out.copy_(batched_sf); + } else { + constexpr int block_mn = 64; + constexpr int num_threads = 512; + const auto& smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast(sizeof(float)); + const TransposeFP32Runtime::Args& args = { + .mn = mn, + .sf_k = sf_k, + .block_mn = block_mn, + .sf = batched_sf.data_ptr(), + .out = out.data_ptr(), + .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, smem_size) + }; + + const auto& code = TransposeFP32Runtime::generate(args); + const auto& runtime = compiler->build("transpose_fp32", code); + TransposeFP32Runtime::launch(runtime, args); + } + return (dim == 2) ? out.squeeze(0) : out; } static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) { @@ -127,7 +178,6 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt)); // Launch the kernel if (batched_sf.is_contiguous()) { - // Fallback to slow PyTorch impl for non-supported cases if ((mn * sf_k) % 4 != 0 and num_groups > 1) return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf); @@ -146,11 +196,8 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code); TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args); } else { - // Fallback to slow PyTorch impl for non-supported cases if (mn % 4 != 0 or num_groups > 1) return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf); - - DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1); DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn); constexpr int block_mn = 128; diff --git a/csrc/python_api.cpp b/csrc/python_api.cpp index 134f272..d4b210a 100644 --- a/csrc/python_api.cpp +++ b/csrc/python_api.cpp @@ -1,412 +1,19 @@ #include #include -#include "jit/compiler.hpp" -#include "jit/device_runtime.hpp" -#include "utils/layout.hpp" - -#include "jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" -#include "jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" -#include "jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp" -#include "jit_kernels/impls/smxx_layout.hpp" +#include "apis/gemm.hpp" +#include "apis/layout.hpp" +#include "apis/runtime.hpp" #ifndef TORCH_EXTENSION_NAME #define TORCH_EXTENSION_NAME deep_gemm_cpp #endif -namespace deep_gemm { -torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, - const int& mn, const int& k, - const std::tuple& recipe, - const std::optional& num_groups, - const bool& is_sfa, - const bool& disable_ue8m0_cast) { - const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe); - const auto& gran_k = std::get<2>(recipe); - const auto& arch_major = device_runtime->get_arch_major(); - - // Pre-transform checks - check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups); - - // (FP32, 1, 128) on SM90: transform to TMA-aligned and MN-major - if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) - return get_mn_major_tma_aligned_tensor(sf); - - // (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major - if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and arch_major == 10) { - DG_HOST_ASSERT(not disable_ue8m0_cast); - return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf); - } - - // (FP32, 128, 128) on SM90: no need to transform, check shape and contiguous - if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) - return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat); - - // (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major - if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and arch_major == 10) { - DG_HOST_ASSERT(not disable_ue8m0_cast); - const auto& broadcasted = sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128)); - return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted); - } - - // (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major - if (sf.scalar_type() == torch::kInt and gran_mn == 1 and gran_k == 128 and arch_major == 10) - return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt); - - DG_HOST_UNREACHABLE("Unknown SF transformation"); -} - -torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf, - const std::vector& ks, - const torch::Tensor& ks_tensor, - const std::tuple& recipe) { - DG_HOST_ASSERT(sf.dim() == 2); - DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); - const auto& arch_major = device_runtime->get_arch_major(); - - // FP32 on SM90 - if (sf.scalar_type() == torch::kFloat and arch_major == 9) - DG_HOST_UNREACHABLE("Unimplemented"); - - // FP32 on SM100 - if (sf.scalar_type() == torch::kFloat and arch_major == 10) - return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks); - - // INT on SM100 - if (sf.scalar_type() == torch::kFloat and arch_major == 10) - DG_HOST_UNREACHABLE("Unimplemented"); - - DG_HOST_UNREACHABLE("Unknown cases"); -} - -void fp8_gemm_nt(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::optional& c, - std::optional> recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - // Shape must be `[M, K] @ [N, K].T` - const auto& major_a = get_major_type_ab(a.first); - const auto& major_b = get_major_type_ab(b.first); - if (fp8_requires_k_major()) { - DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); - DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); - } - - // C/D must be N-major - check_major_type_cd(d); - - // Type and shape checks - const auto& [m , k ] = get_shape<2>(a.first); - const auto& [n , k_] = get_shape<2>(b.first); - const auto& [m_, n_] = get_shape<2>(d); - DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); - DG_HOST_ASSERT(n > 0 and k > 0); - DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); - - // Check C as well - if (c.has_value()) { - check_major_type_cd(c.value()); - DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); - DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); - } - - // Do nothing if the problem is empty - if (m == 0) - return; - - // Transform SFA and SFB into compute-required layout - if (not recipe.has_value()) - recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); - const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast); - const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast); - - // Dispatch into different implements - const auto& arch_major = device_runtime->get_arch_major(); - if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { - sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); - } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); - } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { - sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); - } else { - DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); - } -} - -void fp8_gemm_nn(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::optional& c, - const std::optional>& recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, - d, c, recipe, compiled_dims, disable_ue8m0_cast); -} - -void fp8_gemm_tn(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::optional& c, - const std::optional>& recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, - {b.first.transpose(0, 1), b.second.transpose(0, 1)}, - d, c, recipe, compiled_dims, disable_ue8m0_cast); -} - -void fp8_gemm_tt(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::optional& c, - const std::optional>& recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b, - d, c, recipe, compiled_dims, disable_ue8m0_cast); -} - -void m_grouped_fp8_gemm_nt_contiguous(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const torch::Tensor& m_indices, - std::optional> recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - // Shape must be `[M, K] @ [G, N, K].mT` - const auto& major_a = get_major_type_ab(a.first); - const auto& major_b = get_major_type_ab(b.first); - DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); - if (fp8_requires_k_major()) - DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); - DG_HOST_ASSERT(m_indices.is_contiguous()); - - // Type and shape checks - const auto& [m, k] = get_shape<2>(a.first); - const auto& [num_groups, n, k_] = get_shape<3>(b.first); - const auto& [m_, n_] = get_shape<2>(d); - const auto& m__ = static_cast(m_indices.numel()); - DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); - DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); - DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); - DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt); - - // D must be N-major - check_major_type_cd(d); - - // Do nothing if empty - if (m == 0) - return; - - // Transform SFA and SFB into compute-required layout - if (not recipe.has_value()) - recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); - const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast); - const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast); - - // Dispatch implementation - const auto& arch_major = device_runtime->get_arch_major(); - if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { - sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices, - num_groups, m, n, k, major_a, major_b, compiled_dims); - } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices, - num_groups, m, n, k, major_a, major_b, compiled_dims); - } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { - sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices, - num_groups, m, n, k, major_a, major_b, compiled_dims); - } else { - DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); - } -} - -void m_grouped_fp8_gemm_nn_contiguous(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const torch::Tensor& m_indices, - const std::optional>& recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - m_grouped_fp8_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)}, - d, m_indices, recipe, compiled_dims, disable_ue8m0_cast); -} - -void fp8_m_grouped_gemm_nt_masked(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const torch::Tensor& masked_m, - const int& expected_m, - std::optional> recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - // Shape must be `[G, M, K] @ [G, N, K].mT` - const auto& major_a = get_major_type_ab(a.first); - const auto& major_b = get_major_type_ab(b.first); - DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - DG_HOST_ASSERT(masked_m.is_contiguous()); - - // Type and shape checks - const auto& [num_groups, m, k] = get_shape<3>(a.first); - const auto& [num_groups_, n, k_] = get_shape<3>(b.first); - const auto& [num_groups__, m_, n_] = get_shape<3>(d); - const auto& num_groups___ = static_cast(masked_m.numel()); - DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); - DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); - DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); - DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); - DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); - - // D must be N-major - check_major_type_cd(d); - - // Transform scaling factors - if (not recipe.has_value()) - recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); - const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast); - const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast); - - // Dispatch implementation - const auto& arch_major = device_runtime->get_arch_major(); - if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { - sm90_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, - num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); - } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_fp8_m_grouped_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, - num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); - } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { - sm100_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, - num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); - } else { - DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); - } -} - -void k_grouped_fp8_gemm_tn_contiguous(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::vector& ks, - const torch::Tensor& ks_tensor, - const std::optional& c, - const std::tuple& recipe, - const std::string& compiled_dims) { - // Must be 1D1D kernel - DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); - - // Contiguity checks - DG_HOST_ASSERT(a.first.is_contiguous()); - DG_HOST_ASSERT(b.first.is_contiguous()); - DG_HOST_ASSERT(d.is_contiguous()); - if (c.has_value()) { - DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); - DG_HOST_ASSERT(c.value().is_contiguous()); - } - - // Do nothing if empty - if (std::accumulate(ks.begin(), ks.end(), 0) == 0) - return; - - // Transform SF with padding - const auto& [_, m] = get_shape<2>(a.first); - const auto& [__, n] = get_shape<2>(b.first); - const auto& sfa = transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe); - const auto& sfb = transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe); - - // Dispatch implementation - const auto& arch_major = device_runtime->get_arch_major(); - if (arch_major == 10) { - fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, - cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); - } else { - DG_HOST_UNREACHABLE("Unsupported architecture"); - } -} - -} // namespace deep_gemm - // ReSharper disable once CppParameterMayBeConstPtrOrRef PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - using namespace deep_gemm; - m.doc() = "DeepGEMM C++ library"; - // Runtime - m.def("set_num_sms", [&](const int& new_num_sms) { - device_runtime->set_num_sms(new_num_sms); - }); - m.def("get_num_sms", [&]() { - return device_runtime->get_num_sms(); - }); - m.def("set_tc_util", [&](const int& new_tc_util) { - device_runtime->set_tc_util(new_tc_util); - }); - m.def("get_tc_util", [&]() { - return device_runtime->get_tc_util(); - }); - - // JIT - m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_torch) { - Compiler::prepare_init(library_root_path, cuda_home_path_by_torch); - KernelRuntime::prepare_init(cuda_home_path_by_torch); - }); - - // Stable kernel APIs with automatic arch/layout dispatch - m.def("fp8_gemm_nt", &fp8_gemm_nt, - py::arg("a"), py::arg("b"), py::arg("d"), - py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, - py::arg("compiled_dims") = "nk", - py::arg("disable_ue8m0_cast") = false); - m.def("fp8_gemm_nn", &fp8_gemm_nn, - py::arg("a"), py::arg("b"), py::arg("d"), - py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, - py::arg("compiled_dims") = "nk", - py::arg("disable_ue8m0_cast") = false); - m.def("fp8_gemm_tn", &fp8_gemm_tn, - py::arg("a"), py::arg("b"), py::arg("d"), - py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, - py::arg("compiled_dims") = "mn", - py::arg("disable_ue8m0_cast") = false); - m.def("fp8_gemm_tt", &fp8_gemm_tt, - py::arg("a"), py::arg("b"), py::arg("d"), - py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, - py::arg("compiled_dims") = "mn", - py::arg("disable_ue8m0_cast") = false); - m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous, - py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), - py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", - py::arg("disable_ue8m0_cast") = false); - m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous, - py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), - py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", - py::arg("disable_ue8m0_cast") = false); - m.def("fp8_m_grouped_gemm_nt_masked", &fp8_m_grouped_gemm_nt_masked, - py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), - py::arg("expected_m"), py::arg("recipe") = std::nullopt, - py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); - m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous, - py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), - py::arg("ks_tensor"), py::arg("c") = std::nullopt, - py::arg("recipe") = std::make_tuple(1, 1, 128), - py::arg("compiled_dims") = "mn"); - - // Layout kernels - m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout, - py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"), - py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false, - py::arg("disable_ue8m0_cast") = false); - - // Raw kernels or functions - m.def("get_tma_aligned_size", &get_tma_aligned_size); - m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout); - m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor); - m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor); - m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor); + deep_gemm::gemm::register_apis(m); + deep_gemm::layout::register_apis(m); + deep_gemm::runtime::register_apis(m); } diff --git a/csrc/utils/exception.hpp b/csrc/utils/exception.hpp index 10dedc0..57cc513 100644 --- a/csrc/utils/exception.hpp +++ b/csrc/utils/exception.hpp @@ -2,6 +2,7 @@ #include #include +#include namespace deep_gemm { @@ -10,7 +11,7 @@ class DGException final : public std::exception { public: explicit DGException(const char *name, const char* file, const int line, const std::string& error) { - message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'"; + message = std::string(name) + " error (" + file + ":" + std::to_string(line) + "): " + error; } const char *what() const noexcept override { @@ -50,7 +51,11 @@ do { \ do { \ const auto& e = (cmd); \ if (e != CUDA_SUCCESS) { \ - throw DGException("CUDA driver", __FILE__, __LINE__, ""); \ + std::stringstream ss; \ + const char *name, *info; \ + cuGetErrorName(e, &name), cuGetErrorString(e, &info); \ + ss << static_cast(e) << " (" << name << ", " << info << ")"; \ + throw DGException("CUDA driver", __FILE__, __LINE__, ss.str()); \ } \ } while (0) #endif @@ -60,7 +65,9 @@ do { \ do { \ const auto& e = (cmd); \ if (e != cudaSuccess) { \ - throw DGException("CUDA runtime", __FILE__, __LINE__, std::to_string(static_cast(e))); \ + std::stringstream ss; \ + ss << static_cast(e) << " (" << cudaGetErrorName(e) << ", " << cudaGetErrorString(e) << ")"; \ + throw DGException("CUDA runtime", __FILE__, __LINE__, ss.str()); \ } \ } while (0) #endif diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index e546a30..169e2e6 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -1,6 +1,5 @@ import os -import torch -import torch.utils.cpp_extension +import subprocess # Set some default environment provided at setup try: @@ -12,14 +11,8 @@ try: except ImportError: pass -# Import functions from the CPP module -import deep_gemm_cpp -deep_gemm_cpp.init( - os.path.dirname(os.path.abspath(__file__)), # Library root directory path - torch.utils.cpp_extension.CUDA_HOME # CUDA home -) - # Configs +import deep_gemm_cpp from deep_gemm_cpp import ( set_num_sms, get_num_sms, @@ -34,13 +27,48 @@ from deep_gemm_cpp import ( fp8_gemm_tn, fp8_gemm_tt, m_grouped_fp8_gemm_nt_contiguous, m_grouped_fp8_gemm_nn_contiguous, - fp8_m_grouped_gemm_nt_masked, + m_grouped_fp8_gemm_nt_masked, k_grouped_fp8_gemm_tn_contiguous, + # BF16 GEMMs + bf16_gemm_nt, bf16_gemm_nn, + bf16_gemm_tn, bf16_gemm_tt, + m_grouped_bf16_gemm_nt_contiguous, + m_grouped_bf16_gemm_nt_masked, # Layout kernels transform_sf_into_required_layout ) +# Some alias for legacy supports +# TODO: remove these later +fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_gemm_nt_masked +bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked + # Some utils from . import testing from . import utils from .utils import * + + +# Initialize CPP modules +def _find_cuda_home() -> str: + # TODO: reuse PyTorch API later + # For some PyTorch versions, the original `_find_cuda_home` will initialize CUDA, which is incompatible with process forks + cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') + if cuda_home is None: + # noinspection PyBroadException + try: + with open(os.devnull, 'w') as devnull: + nvcc = subprocess.check_output(['which', 'nvcc'], stderr=devnull).decode().rstrip('\r\n') + cuda_home = os.path.dirname(os.path.dirname(nvcc)) + except Exception: + cuda_home = '/usr/local/cuda' + if not os.path.exists(cuda_home): + cuda_home = None + assert cuda_home is not None + return cuda_home + + +deep_gemm_cpp.init( + os.path.dirname(os.path.abspath(__file__)), # Library root directory path + _find_cuda_home() # CUDA home +) diff --git a/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep_gemm/include/deep_gemm/common/scheduler.cuh index bada914..8ac8310 100644 --- a/deep_gemm/include/deep_gemm/common/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -11,16 +11,22 @@ enum class KGroupedIndexType { SF_K, }; -template +template static constexpr uint32_t get_num_1d_blocks_per_group() { // Select the best from candidates uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); - for (const auto& candidate: {8u, 16u}) { - const auto& usage = isMulticastOnA ? - candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N - candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M - if (usage < min_usage) - min_usage = usage, num_best_blocks = candidate; + if constexpr (kGemmType == GemmType::MGroupedContiguous or + kGemmType == GemmType::MGroupedMasked) { + // For grouped GEMMs, let weights always stay in the L2 cache and read activations by once + num_best_blocks = kNumSMs; + } else { + for (const auto& candidate: {8u, 16u}) { + const auto& usage = kIsMulticastOnA ? + candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } } return num_best_blocks; } @@ -32,7 +38,7 @@ template ()> + uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group()> struct Scheduler { int current_iter = -1; diff --git a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh index 879abda..e590b47 100644 --- a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -48,7 +48,18 @@ struct FP8MMASelector { if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN(); if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN(); } static constexpr auto select_type() { @@ -58,6 +69,71 @@ struct FP8MMASelector { using type = decltype(select_type()); }; +template +struct BF16MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 16; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct BF16MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); + if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); + if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); + if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS(); + if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS(); + if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS(); + } + + static constexpr auto select_type() { + return BF16MMA(); + } + + using type = decltype(select_type()); +}; + + template struct SM90_U32x2_STSM_N { __device__ __forceinline__ static void diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh index 7851327..fc84b69 100644 --- a/deep_gemm/include/deep_gemm/common/utils.cuh +++ b/deep_gemm/include/deep_gemm/common/utils.cuh @@ -144,4 +144,22 @@ __device__ __forceinline__ void prefetch_l1(void *ptr) { asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); } +template +struct Vectorized { + static auto zeros() { + // TODO: add `ulonglong4` for SM100 once `__ldg` support this + if constexpr (kNumBytes > 0 and kNumBytes % 16 == 0) { + return make_uint4(0, 0, 0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 8 == 0) { + return make_uint2(0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 4 == 0) { + return 0; + } else { + DG_STATIC_ASSERT(kNumBytes > 0 and kNumBytes % 4 == 0, "Invalid vectorization"); + } + } + + using vec_t = decltype(zeros()); +}; + } // namespace `deep_gemm` diff --git a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh index 28b5399..789e220 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -1,3 +1,498 @@ #pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" -// TODO: add implement \ No newline at end of file +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_c, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M; + constexpr uint32_t kNumTMAStoreStages = 2; + DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == 0) { + // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + if constexpr (kWithAccumulation) + cute::prefetch_tma_descriptor(&tensor_map_c); + } + + // Data on shared memory (layout as ordered below) + cd_dtype_t* smem_cd[kNumTMAStoreStages]; + cutlass::bfloat16_t* smem_a[kNumStages]; + cutlass::bfloat16_t* smem_b[kNumStages]; + + // Fill D/A/B pointers + #pragma unroll + for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i) + smem_cd[i] = reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + } + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (threadIdx.x == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive only at the leader CTA + full_barriers[i]->init(kNumMulticast); + // Arrive at all CTAs + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } else if (threadIdx.x >= 32 and threadIdx.x < 64) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + uint32_t phase = 0; + auto launch_k_iterations = [&](const auto& func) { + const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K); + const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages; + + // TODO: refactor here + if (num_last_stages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1) + func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1) + func(k_iter, DivisibleK{}, false, num_last_stages); + func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1; + } + }; + + auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) { + DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2, + "Too many epilogue stages, please modify the Python heuristic as well"); + accum_stage_idx == 0 ? func(0) : func(1); + }; + + // Dispatch warps into different roles + if (warp_idx == 0) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_block_idx = k_iter * kNumStages + s; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + if (cute::elect_one_sync()) { + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx); + } + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); + if (not is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive(0u); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait(phase ^ 1); + if (is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive(); + if (not is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive(0u); + } + }); + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { + // Wait tensor memory empty barrier arrival + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[s])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait TMA arrival + full_barriers[s]->wait(phase); + tcgen05_after_thread_sync(); + + // Let tensor cores relax for lower possibility of frequency drop + DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control"); + if constexpr (kTensorCoreUtilControl < 100) { + constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; + const auto& start_clock = clock64(); + if (cute::elect_one_sync()) + while (clock64() - start_clock < kNumDummyCycles) {} + __syncwarp(); + } + + // Issue UMMA in the leader CTA + using cute_mma_t = cute::conditional_t, + cute::SM100_MMA_F16BF16_2x1SM_SS>; + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K); + cute_mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_iter > 0 or s > 0 or k > 0, + runtime_instr_desc); + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait(phase); + empty_barrier_arrive(s, false); + } + }); + }); + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32) { + // Epilogue warp groups + const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads; + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) { + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Flush TMA stores + // NOTES: for the first store, we have to flush all previous TMA, + // as we don't share pipeline stages between two blocks + if (epilogue_thread_idx == 0) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + // Wait shared memory to be released + const uint32_t iter_idx = w * kNumStores + s; + if (iter_idx >= kNumTMAStoreStages) { + if (epilogue_thread_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + } + + // The pipeline stage + const auto tma_stage_idx = iter_idx % kNumTMAStoreStages; + const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + __syncwarp(); + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); + if (epilogue_thread_idx == 0) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + } + }); + } + + // Flush all stages in the pipeline to make TMA stores visible to the next kernel + // TODO: do we actually need this? + if (epilogue_thread_idx == 0) + cute::tma_store_wait<0>(); + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + // TODO: do we need 2 SM allocation? + if (epilogue_warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } + + // To safely deconstruct all barriers, we need a cluster sync + // TODO: optimize it by another round of barrier waits + if constexpr (kNumMulticast > 1) + cute::cluster_sync(); +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh index a78a7b1..88b6b50 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh @@ -136,7 +136,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, #pragma unroll for (uint32_t i = 0; i < kNumStages; ++ i) { // Arrive at all CTAs - full_barriers[i]->init(1); + full_barriers[i]->init(kNumMulticast); empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads / 32); } #pragma unroll @@ -241,6 +241,8 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE; if (is_leader_cta and cute::elect_one_sync()) full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); + if (not is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive(0u); } // Wait unaligned cases @@ -249,6 +251,8 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); if (is_leader_cta and cute::elect_one_sync()) full_barriers[s]->arrive(); + if (not is_leader_cta and cute::elect_one_sync()) + full_barriers[s]->arrive(0u); } }); } diff --git a/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh index 0ccec3e..23045e1 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -1,3 +1,343 @@ #pragma once -// TODO: add implement +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Types + using WGMMA = typename BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == kNumMathThreads) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_bfloat16* smem_a[kNumStages]; + __nv_bfloat16* smem_b[kNumStages]; + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + + // Fill shared memory pointers + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast<__nv_bfloat16*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_bfloat16*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + } + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + if (threadIdx.x == kNumMathThreads) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + struct DivisibleK {}; + struct NotDivisibleK {}; + auto launch_k_iterations = [=](const auto& func) { + if constexpr (kNumLastStages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}); + func(num_iterations - 1, NotDivisibleK{}); + } + }; + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 48; + constexpr uint32_t kNumMathRegisters = 224; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto divisible_type) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[s]; + uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a); + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Launch MMAs + launch_k_iterations([&](uint32_t k_iter, auto divisible_type) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + + // TODO: remove some useless computation for unaligned Ms + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, shifted_accum, 1); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(s); + } + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type"); + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling, just padding + // TODO: support more cases + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + // TODO: compatible with FP32 output + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; + DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index c78f72d..5a65d69 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -175,7 +175,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, cutlass::arch::warpgroup_reg_dealloc(); // NOTES: only one thread (or warp) will be used - if (threadIdx.x == kNumMathThreads) { + if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) { // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) { diff --git a/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh b/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh index 7385f91..bea7000 100644 --- a/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh +++ b/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh @@ -4,6 +4,45 @@ namespace deep_gemm { +template +__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { + typedef typename Vectorized::vec_t in_vec_t; + constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float); + constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec; + + // Shapes and strides + extern __shared__ float smem_buffer[]; + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(float)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + + // Shift into the block + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * SF_K; + const auto& local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + + // Load + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) { + auto in_vec = __ldg(local_sf + i); + const auto& in_values = reinterpret_cast(&in_vec); + + const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec; + #pragma unroll + for (uint32_t j = 0; j < kNumElemsPerVec; ++ j) + smem_buffer[row * PADDED_SF_K + col + j] = in_values[j]; + } + __syncthreads(); + + // Store + #pragma unroll + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) { + const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn; + const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; + out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); + } +} + // NOTES: the two kernels below always pack the K dimension template diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index b61abcd..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,3 +0,0 @@ -[build-system] -requires = ["torch>=2.1.0"] -build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 1c1e618..8ececfc 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,10 @@ third_party_include_dirs = [ 'third-party/cutlass/include/cutlass', ] +# Use driver API for older CUDA compatibility +if int(os.environ.get('DG_JIT_USE_DRIVER_API', '0')): + cxx_flags.append('-DDG_JIT_USE_DRIVER_API') + class CustomBuildPy(build_py): def run(self): diff --git a/tests/generators.py b/tests/generators.py index 21c050a..82cdbdc 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -59,6 +59,7 @@ def get_out_dtype() -> tuple: def get_major_ab(freeze_a: bool) -> tuple: + # TODO: test other major-ness for SM90 BF16 GEMMs if get_arch_major() == 9: return ((MajorTypeAB.KMajor, MajorTypeAB.KMajor), ) if freeze_a: @@ -70,15 +71,15 @@ def get_major_ab(freeze_a: bool) -> tuple: def enumerate_normal(use_bf16: bool = False) -> Generator: for kernel_type in get_kernel_types(use_bf16): for m in (128, 4096): - for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048), (129280, 7168)]: + for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]: for major_a, major_b in get_major_ab(False): for out_dtype in get_out_dtype(): - for accumulate in (False, ) if out_dtype == torch.bfloat16 or not kernel_type.is_1d1d() else (False, True): + for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True): yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype -def enumerate_m_grouped_contiguous() -> Generator: - for kernel_type in get_kernel_types(): +def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator: + for kernel_type in get_kernel_types(use_bf16): for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)): for major_a, major_b in get_major_ab(True): yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b @@ -106,15 +107,12 @@ def enumerate_k_grouped_contiguous(): def enumerate_sf_layout(): - for with_transpose in (True, False): - for mn in (4096, 4097, 8192): - for k in (128, 7168, 7296): - for num_groups in (1, 2, 4) if with_transpose else (1, ): - if num_groups > 1 and (mn * ceil_div(k, 128)) % 4 != 0: - continue - if not with_transpose and mn % 4 != 0: - continue - yield mn, k, with_transpose, num_groups + for use_ue8m0 in (False, True): + for with_transpose in (True, False): + for mn in (4096, 4097, 8192): + for k in (128, 7168, 7296): + for num_groups in (1, 2, 4): + yield mn, k, with_transpose, use_ue8m0, num_groups def enumerate_k_grouped_sf_layout(): @@ -126,6 +124,13 @@ def enumerate_k_grouped_sf_layout(): yield mn, ks, num_groups +def enumerate_transpose(): + for mn in (64, 4096, 16384): + for delta in (0, 101, 202, 303): + for k in (128, 1024, 4096, 9984, 16384): + yield mn + delta, k + + def generate_normal(m: int, n: int, k: int, major_a: MajorTypeAB, major_b: MajorTypeAB, accumulate: bool, out_dtype: torch.dtype, @@ -149,8 +154,8 @@ def generate_normal(m: int, n: int, k: int, def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int, - major_a: MajorTypeAB, major_b: MajorTypeAB, use_ue8m0: bool) -> \ - Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + major_a: MajorTypeAB, major_b: MajorTypeAB, + use_ue8m0: bool = False, use_bf16: bool = False): actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms] m = sum(aligned_ms) @@ -171,6 +176,10 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: start = aligned_end ref_d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_d), ref_d) + if use_bf16: + b = b if major_b.is_k_major() else b.mT.contiguous().mT + return m, a, b, m_indices, d, ref_d + assert major_a.is_k_major() a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0) b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), @@ -181,24 +190,27 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: return m, a_fp8, b_fp8, m_indices, d, ref_d -def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, use_ue8m0: bool) -> \ - Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: +def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, + use_ue8m0: bool = False, use_bf16: bool = False): a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) ref_d = torch.einsum('gmk,gnk->gmn', a, b) + masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) + for j in range(num_groups): + masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) + assert masked_m.amax().item() <= max_m + + if use_bf16: + return a, b, masked_m, d, ref_d + a_fp8 = (torch.empty_like(a, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, ceil_div(k, 128)), device='cuda', dtype=torch.float)) b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float)) for i in range(num_groups): a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0) b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) - masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) - for j in range(num_groups): - masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) - assert masked_m.amax().item() <= max_m - return a_fp8, b_fp8, masked_m, d, ref_d diff --git a/tests/test_bf16.py b/tests/test_bf16.py new file mode 100644 index 0000000..790f700 --- /dev/null +++ b/tests/test_bf16.py @@ -0,0 +1,125 @@ +import torch +import random + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes +) +from generators import ( + enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, generate_normal, + generate_m_grouped_contiguous, generate_m_grouped_masked +) + + +def test_gemm() -> None: + print('Testing GEMM:') + for _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(use_bf16=True): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + out_opt = 'FP32' if out_dtype == torch.float else 'BF16' + acc_opt = f'acc={int(accumulate)}' + + for test_alias in (False, True): + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True) + func_name = f'bf16_gemm_{major_opt.lower() if test_alias else "nt"}' + if test_alias: + a = a if major_a.is_k_major() else a.T + b = b if major_b.is_k_major() else b.T + assert a.is_contiguous() and b.is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, c=c) + diff = calc_diff(d, ref_d) + assert diff < 0.0001, (f'{m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=}, ' + f'{diff:.5f}, alias={test_alias}') + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True) + + cublas_t = 0 + t = bench_kineto(lambda: deep_gemm.bf16_gemm_nt(a, b, d, c=c), 'bf16_gemm', suppress_kineto_output=True) + if accumulate == 0 and out_dtype == torch.bfloat16: + # noinspection PyBroadException + try: + cublas_t = bench_kineto(lambda: a @ b.T, 'nvjet', suppress_kineto_output=True) + except Exception: + pass + print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' + f'{cublas_t / t:.2f}x cuBLAS') + print() + + +def test_m_grouped_gemm_contiguous() -> None: + print('Testing m-grouped contiguous GEMM:') + + for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(use_bf16=True): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + + for test_alias in (False, True): + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + func_name = f"m_grouped_bf16_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else b.mT + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, m_indices) + d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d) + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, m_indices) + + t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_masked() -> None: + print('Testing m-grouped masked GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. + for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(): + # Test correctness + for i in range(10): + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True) + deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group) + for j in range(num_groups): + diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' + + # Construct full cases + a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group) + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_gemm() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked() diff --git a/tests/test_fp8.py b/tests/test_fp8.py index 289d20e..0c7d3ce 100644 --- a/tests/test_fp8.py +++ b/tests/test_fp8.py @@ -105,7 +105,7 @@ def test_m_grouped_gemm_masked() -> None: # Test correctness for i in range(10): a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - deep_gemm.fp8_m_grouped_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) for j in range(num_groups): diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' @@ -115,7 +115,7 @@ def test_m_grouped_gemm_masked() -> None: # noinspection PyShadowingNames def test_func(): - deep_gemm.fp8_m_grouped_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) + deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) # Test performance with fixed shapes valid_m = masked_m.sum().item() diff --git a/tests/test_layout.py b/tests/test_layout.py index 6cad642..42d7208 100644 --- a/tests/test_layout.py +++ b/tests/test_layout.py @@ -1,16 +1,18 @@ import time import torch import random -from deep_gemm.testing import bench_kineto, count_bytes +from deep_gemm.testing import bench_kineto, count_bytes, calc_diff from deep_gemm.utils import ( align, ceil_div, per_token_cast_to_fp8, per_channel_cast_to_fp8, get_tma_aligned_size, + get_mn_major_tma_aligned_tensor, get_mn_major_tma_aligned_packed_ue8m0_tensor, get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor ) from generators import ( + enumerate_transpose, enumerate_sf_layout, enumerate_k_grouped_sf_layout ) @@ -43,29 +45,39 @@ def get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(x: torch.Tensor) -> def test_sf_layout_kernels() -> None: print('Testing SF layout kernels:') - for mn, k, with_transpose, num_groups in enumerate_sf_layout(): + for mn, k, with_transpose, use_ue8m0, num_groups in enumerate_sf_layout(): x = torch.randn((num_groups * mn, k), dtype=torch.bfloat16, device='cuda') - x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=True) + x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0) fp32_sf = fp32_sf if num_groups == 1 else fp32_sf.view(num_groups, mn, -1) fp32_sf = fp32_sf if with_transpose else fp32_sf.transpose(-1, -2).contiguous().transpose(-1, -2) # Correctness - packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf) - ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf) - assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}' - assert packed_sf.shape == ref_packed_sf.shape - assert all([packed_sf.stride(i) == ref_packed_sf.stride(i) for i in range(packed_sf.dim())]) - - # Test launch overhead - launch_start_t = time.time_ns() - get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf) - launch_end_t = time.time_ns() + if use_ue8m0: + impl, name = get_mn_major_tma_aligned_packed_ue8m0_tensor, 'pack_fp32_into_ue8m0' + packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf) + ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf) + assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}' + assert packed_sf.shape == ref_packed_sf.shape + assert all([packed_sf.stride(i) == ref_packed_sf.stride(i) for i in range(packed_sf.dim())]) + else: + impl, name = get_mn_major_tma_aligned_tensor, 'transpose' + transposed_sf = get_mn_major_tma_aligned_tensor(fp32_sf) + tma_aligned_mn, sf_k = get_tma_aligned_size(mn, fp32_sf.element_size()), ceil_div(k, 128) + if num_groups > 1: + assert transposed_sf.size(0) == num_groups + assert transposed_sf.stride(0) == tma_aligned_mn * sf_k + assert transposed_sf.shape[-2:] == (mn, sf_k) + assert transposed_sf.stride()[-2:] == (1, tma_aligned_mn) + assert torch.equal(fp32_sf, transposed_sf) # Performance - t = bench_kineto(lambda: get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf), 'pack_fp32_into_ue8m0') - print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}): ' - f'launch {(launch_end_t - launch_start_t) / 1e3:3.0f} us | {t * 1e6:4.0f} us | ' - f'{count_bytes(fp32_sf, packed_sf) / 1e9 / t:4.0f} GB/s') + try: + t = bench_kineto(lambda: impl(fp32_sf), name) + except AssertionError as e: + # Some cases may fallback to PyTorch impl + t = 0 + print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}, use_ue8m0={int(use_ue8m0)}): ' + f'{t * 1e6:4.0f} us | {count_bytes(fp32_sf, impl(fp32_sf)) / 1e9 / t if t else 0:4.0f} GB/s') print() diff --git a/tests/test_lazy_init.py b/tests/test_lazy_init.py new file mode 100644 index 0000000..5363b6d --- /dev/null +++ b/tests/test_lazy_init.py @@ -0,0 +1,15 @@ +import torch +import torch.multiprocessing as mp +import deep_gemm + + +def main(local_rank: int): + torch.cuda.set_device(local_rank) + + +if __name__ == '__main__': + procs = [mp.Process(target=main, args=(i, ), ) for i in range(8)] + for p in procs: + p.start() + for p in procs: + p.join()