Make various updates and fixes (#198)

This commit is contained in:
Ray Wang
2025-09-25 16:19:07 +08:00
committed by GitHub
parent 79f48ee15a
commit 3f71de7aa9
45 changed files with 3281 additions and 1060 deletions

77
csrc/apis/attention.hpp Normal file
View File

@@ -0,0 +1,77 @@
#pragma once
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
#include "layout.hpp"
namespace deep_gemm::attention {
static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::tuple<int, int, int> &head_splits,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
// Shape must be `[M, K] @ [N, K].T`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
if (fp8_requires_k_major()) {
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
}
// D must be N-major
check_major_type_cd(d);
// Type and shape checks
const auto& [m , k ] = get_shape<2>(a.first);
const auto& [n , k_] = get_shape<2>(b.first);
const auto& [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and k == k_);
DG_HOST_ASSERT(n > 0 and k > 0);
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
// Check head splits and N
const auto& [left, mid, right] = head_splits;
DG_HOST_ASSERT(n % (left + right) == 0 and n_ == n + n / (left + right) * mid);
// Do nothing if the problem is empty
if (m == 0)
return;
// Transform SFA and SFB into compute-required layout
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128));
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
// Dispatch into different implements
const auto& arch_major = device_runtime->get_arch_major();
const auto& epilogue_type = fmt::format("EpilogueHeadSplits<{}, {}, {}>", left, mid, right);
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) {
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
}
static void register_apis(pybind11::module_& m) {
m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("head_splits"),
py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
}
} // namespace deep_gemm::attention

115
csrc/apis/einsum.hpp Normal file
View File

@@ -0,0 +1,115 @@
#pragma once
#include <pybind11/pybind11.h>
#include <torch/python.h>
#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/layout.hpp"
#include "../jit_kernels/impls/sm90_bmk_bnk_mn.hpp"
#include "../jit_kernels/impls/sm100_bmk_bnk_mn.hpp"
#include "../jit_kernels/impls/smxx_cublaslt.hpp"
namespace deep_gemm::einsum {
static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d,
const std::optional<torch::Tensor>& c) {
// Currently FP32 only support the accumulated expression
if (d.scalar_type() == torch::kFloat) {
DG_HOST_ASSERT(c->data_ptr() == d.data_ptr() and c->sizes() == d.sizes() and c->strides() == d.strides());
} else {
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(not c.has_value());
const auto& workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32));
DG_CUDA_RUNTIME_CHECK(cudaMemsetAsync(workspace.data_ptr(), 0, workspace.nbytes(),
c10::cuda::getCurrentCUDAStream()));
bmk_bnk_mn(a, b, workspace, workspace);
// This line has an implicit FP32-to-BF16 casting
d.copy_(workspace);
return;
}
DG_HOST_ASSERT(a.is_contiguous());
DG_HOST_ASSERT(b.is_contiguous());
DG_HOST_ASSERT(d.is_contiguous());
const auto& [s , m, k ] = get_shape<3>(a);
const auto& [s_, n, k_] = get_shape<3>(b);
DG_HOST_ASSERT(s == s_ and k == k_);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
sm90_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
} else if (arch_major == 10) {
sm100_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
}
static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D) {
const auto& [b , h , r ] = get_shape<3>(A);
const auto& [h_, d , r_] = get_shape<3>(B);
const auto& [b_, h__, d_] = get_shape<3>(D);
DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);
DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1);
DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);
cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d);
}
static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D) {
const auto& [b , h , d ] = get_shape<3>(A);
const auto& [h_, d_ , r ] = get_shape<3>(B);
const auto& [b_, h__, r_] = get_shape<3>(D);
DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);
DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1);
DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);
cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d);
}
static void einsum(const std::string& expr,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c) {
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
if (c.has_value()) {
DG_HOST_ASSERT(c->scalar_type() == torch::kFloat);
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
}
// Some hardcoded Einstein sum kernels
// TODO: support any expression
// TODO: canonicalize expression
if (expr == "bmk,bnk->mn") {
bmk_bnk_mn(a, b, d, c);
} else if (expr == "bhr,hdr->bhd") {
DG_HOST_ASSERT(not c.has_value());
bhr_hdr_bhd(a, b, d);
} else if (expr == "bhd,hdr->bhr") {
DG_HOST_ASSERT(not c.has_value());
bhd_hdr_bhr(a, b, d);
} else {
DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr));
}
}
static void register_apis(pybind11::module_& m) {
m.def("einsum", &einsum,
py::arg("expr"), py::arg("a"), py::arg("b"),
py::arg("d"), py::arg("c") = std::nullopt);
}
} // namespace deep_gemm::einsum

View File

@@ -1,5 +1,6 @@
#pragma once
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
@@ -52,13 +53,18 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
// Transform SFA and SFB into compute-required layout
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128));
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
// Dispatch into different implements
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
if (std::get<1>(recipe.value()) == 1) {
sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
} else {
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
}
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
@@ -261,6 +267,60 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torc
}
}
static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::vector<int>& ks,
const torch::Tensor& ks_tensor,
const std::optional<torch::Tensor>& c,
const std::tuple<int, int, int>& recipe,
const std::string& compiled_dims) {
// Must be 1D1D kernel
DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
// Shape checks
const auto& [num_groups, m, n] = get_shape<3>(d);
const auto& sum_mk = a.first.numel();
const auto& sum_nk = b.first.numel();
int sum_k = 0;
for (const auto& k: ks)
sum_k += k;
DG_HOST_ASSERT(sum_mk == m * sum_k);
DG_HOST_ASSERT(sum_nk == n * sum_k);
// Contiguity checks
DG_HOST_ASSERT(a.first.is_contiguous());
DG_HOST_ASSERT(b.first.is_contiguous());
DG_HOST_ASSERT(d.is_contiguous());
if (c.has_value()) {
DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
DG_HOST_ASSERT(c.value().is_contiguous());
}
// Do nothing if empty
if (std::accumulate(ks.begin(), ks.end(), 0) == 0)
return;
// Transform SF with padding
const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);
// Allocate tensormap buffer
// `4` means the double buffering for both A and B operands (2 * 2)
const auto& num_sms = device_runtime->get_num_sms();
const auto& tensor_map_buffer = torch::empty({num_sms * 4 * static_cast<int>(sizeof(CUtensorMap))},
a.first.options().dtype(torch::kByte));
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
sm90_fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer,
cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
}
static void bf16_gemm_nt(const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& d,
@@ -403,6 +463,43 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T
}
}
static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
// Shape must be `[M, K] @ [N, K].T`
const auto& major_a = get_major_type_ab(a);
const auto& major_b = get_major_type_ab(b);
// Type and shape checks
const auto& [m , k ] = get_shape<2>(a);
const auto& [n , k_] = get_shape<2>(b);
const auto& [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
if (c.has_value())
DG_HOST_ASSERT(c.value().scalar_type() == d.scalar_type());
// Do nothing if the problem is empty
if (m == 0 or n == 0)
return;
cublaslt_gemm(a, b, c, d, m, n, k, major_a, major_b);
}
static void cublaslt_gemm_nn(const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
cublaslt_gemm_nt(a, b.transpose(0, 1), d, c);
}
static void cublaslt_gemm_tn(const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
cublaslt_gemm_nt(a.transpose(0, 1), b.transpose(0, 1), d, c);
}
static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
cublaslt_gemm_nt(a.transpose(0, 1), b, d, c);
}
static void register_apis(pybind11::module_& m) {
// FP8 GEMMs
m.def("fp8_gemm_nt", &fp8_gemm_nt,
@@ -442,6 +539,11 @@ static void register_apis(pybind11::module_& m) {
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
py::arg("recipe") = std::make_tuple(1, 1, 128),
py::arg("compiled_dims") = "mn");
m.def("k_grouped_fp8_gemm_nt_contiguous", &k_grouped_fp8_gemm_nt_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
py::arg("recipe") = std::make_tuple(1, 1, 128),
py::arg("compiled_dims") = "mn");
// BF16 GEMMs
m.def("bf16_gemm_nt", &bf16_gemm_nt,
@@ -466,6 +568,16 @@ static void register_apis(pybind11::module_& m) {
m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"), py::arg("compiled_dims") = "nk");
// cuBLASLt GEMMs
m.def("cublaslt_gemm_nt", &cublaslt_gemm_nt,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
m.def("cublaslt_gemm_nn", &cublaslt_gemm_nn,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
m.def("cublaslt_gemm_tn", &cublaslt_gemm_tn,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
m.def("cublaslt_gemm_tt", &cublaslt_gemm_tt,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
}
} // namespace deep_gemm::gemm

View File

@@ -56,14 +56,14 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
// FP32 on SM90
if (sf.scalar_type() == torch::kFloat and arch_major == 9)
DG_HOST_UNREACHABLE("Unimplemented");
return get_mn_major_tma_aligned_tensor(sf);
// FP32 on SM100
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks);
// INT on SM100
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
if (sf.scalar_type() == torch::kInt and arch_major == 10)
DG_HOST_UNREACHABLE("Unimplemented");
DG_HOST_UNREACHABLE("Unknown cases");

View File

@@ -1,9 +1,16 @@
// GEMM kernels
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
// Einsum kernels
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
// Layout kernels
#include <deep_gemm/impls/smxx_layout.cuh>
using namespace deep_gemm;

View File

@@ -1,5 +1,6 @@
#pragma once
#include <cublasLt.h>
#include <ATen/cuda/CUDAContext.h>
#include "../utils/exception.hpp"
@@ -11,8 +12,28 @@ class DeviceRuntime {
int num_sms = 0, tc_util = 0;
std::shared_ptr<cudaDeviceProp> cached_prop;
// cuBLASLt utils
static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
cublasLtHandle_t cublaslt_handle{};
std::shared_ptr<torch::Tensor> cublaslt_workspace;
public:
explicit DeviceRuntime() = default;
explicit DeviceRuntime() {
cublaslt_workspace = std::make_shared<torch::Tensor>(torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)));
DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle));
}
~DeviceRuntime() noexcept(false) {
DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle));
}
cublasLtHandle_t get_cublaslt_handle() const {
return cublaslt_handle;
}
torch::Tensor get_cublaslt_workspace() const {
return *cublaslt_workspace;
}
std::shared_ptr<cudaDeviceProp> get_prop() {
if (cached_prop == nullptr)

View File

@@ -1,5 +1,7 @@
#pragma once
#include <deep_gemm/common/types.hpp>
#include "../../utils/math.hpp"
#include "../../utils/layout.hpp"
@@ -80,18 +82,19 @@ static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
return divisible and num_sms % num_multicast == 0;
}
static int get_swizzle_mode(const int& block_size, const int& elem_size) {
template <typename size_type_t>
static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) {
// `> 0` means interleaving
// 16B actually means non-swizzling (but interleaving)
for (const int& mode: {128, 64, 32, 16}) {
if ((block_size * elem_size) % mode == 0)
if ((block_size * static_cast<int>(elem_size)) % mode == 0)
return mode;
}
DG_HOST_UNREACHABLE("Unreachable");
}
template <typename ArchSpec>
static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const KernelType& kernel_type,
const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
@@ -104,7 +107,7 @@ static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n);
const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size);
const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size);
const int& swizzle_cd_mode = get_swizzle_mode(block_n, cd_elem_size);
const int& swizzle_cd_mode = ArchSpec::enable_cd_swizzle(cd_dtype) ? get_swizzle_mode(block_n, cd_elem_size) : 0;
// Different archs have different epilogue pipelines
const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype);
@@ -121,9 +124,11 @@ static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
// M-barriers and tensor memory pointers
const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages);
const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size();
const int& smem_tensor_map = ArchSpec::get_tensormap_smem_size(gemm_type);
// Sum them up
int smem_size = 0;
smem_size += smem_tensor_map;
smem_size += smem_cd;
smem_size += num_stages * smem_a_per_stage;
smem_size += num_stages * smem_b_per_stage;
@@ -151,15 +156,12 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
// Select M/N block sizes
// TODO: support `% 16 == 8` block size on SM90
auto block_ms = std::vector{64, 128, 256};
if (gemm_type == GemmType::MGroupedContiguous)
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
block_ms = std::vector{64, 128};
std::vector<int> block_ns;
for (int i = 16; i <= 256; i += 16)
block_ns.push_back(i);
const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype);
// K block size is selected in a fixed manner
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
@@ -214,9 +216,9 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0);
// Decide the number of TMA multicasts and whether broadcast on A
MulticastConfig best_multicast_config = {1, true};
MulticastConfig best_multicast_config = {1, false};
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
gemm_type, m, n, best_block_m, best_block_n, num_sms);
gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms);
const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
bool order[2] = {false, true};
if (best_block_m > best_block_n)
@@ -232,11 +234,11 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
constexpr int smem_capacity = ArchSpec::smem_capacity;
int best_num_stages = 0;
SharedMemoryConfig best_smem_config;
for (int num_stages = std::min(12, ceil_div(k, block_k)); num_stages > 0; -- num_stages) {
for (int num_stages = 12; num_stages > 0; -- num_stages) {
if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
continue;
best_smem_config = get_smem_config<ArchSpec>(kernel_type,
best_smem_config = get_smem_config<ArchSpec>(gemm_type, kernel_type,
m, n, k,
best_block_m, best_block_n, block_k,
major_a, major_b,

View File

@@ -12,6 +12,15 @@ namespace deep_gemm {
struct SM100ArchSpec {
static constexpr int smem_capacity = 232448;
static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
// 16 is for better SM usage
// Stride 32 is due to low-performance swizzle-16/32B
std::vector<int> candidates = {16};
for (int i = 32; i <= 256; i += 32)
candidates.push_back(i);
return candidates;
}
static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) {
return block_m / (config.is_multicast_on_a ? config.num_multicast : 1);
}
@@ -29,6 +38,10 @@ struct SM100ArchSpec {
return block_n;
}
static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
return true;
}
static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) {
constexpr int num_utccp_aligned_elems = 128;
@@ -86,7 +99,7 @@ struct SM100ArchSpec {
return false;
}
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
const int& m, const int& n, const int& block_m, const int& block_n,
const int& num_sms) {
// TODO: support other layouts
@@ -138,12 +151,17 @@ struct SM100ArchSpec {
// TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers
// NOTES: 1D2D kernel will not use the with-SF full barriers
// NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
return num_stages * 8 * 3 + 2 * 8 * 2;
// NOTES: the last barrier is for tensor core utilization control
return num_stages * 8 * 3 + 2 * 8 * 2 + 8;
}
static int get_tmem_ptr_smem_size() {
return 4;
}
static int get_tensormap_smem_size(const GemmType& gemm_type) {
return 0;
}
};
} // namespace deep_gemm

View File

@@ -11,6 +11,15 @@ namespace deep_gemm {
struct SM90ArchSpec {
static constexpr int smem_capacity = 232448;
static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
// Avoid bank conflicts for FP32 output
const auto& start = cd_dtype == torch::kFloat ? 8 : 16;
std::vector<int> candidates;
for (int i = start; i <= 256; i += 16)
candidates.push_back(i);
return candidates;
}
static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) {
return block_m;
}
@@ -19,26 +28,35 @@ struct SM90ArchSpec {
return block_n;
}
static int get_cd_store_block_m(const int& block_m) {
return block_m;
static int get_cd_store_block_m(const int& block_m, const bool& single_warpgroup_sync = false) {
constexpr int wgmma_m = 64;
return single_warpgroup_sync ? wgmma_m : block_m;
}
static int get_cd_store_block_n(const int& block_n) {
return block_n;
}
static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
return cd_dtype != torch::kFloat;
}
static bool is_block_size_legal(const KernelType& kernel_type,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const int& block_m, const int& block_n, const int& block_k) {
// FP32 output does not support `block_m == 256`
// SM90 FP32 output does not support `block_m == 256`
if (cd_dtype == at::kFloat and block_m == 256)
return false;
// TODO: more general block N selection
// Must be some fixed block N selections
if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 and block_n != 152))
return false;
// Avoid large C/D shared memory for FP32 output
// Ensure `num_stages >= 4` (for 1D1D Kernel), `num_stages >= 3` (for No SF kernel)
if (block_n > 128 and cd_dtype == torch::kFloat) {
if (kernel_type == KernelType::Kernel1D1D and block_n > 152)
return false;
if (kernel_type == KernelType::KernelNoSF and block_n > 200)
return false;
}
// Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
// Or too many register spills
@@ -66,9 +84,13 @@ struct SM90ArchSpec {
return true;
}
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
const int& m, const int& n, const int& block_m, const int& block_n,
const int& num_sms) {
// Disable multicast when the number of k-groups is large (a heuristic)
if (gemm_type == GemmType::KGroupedContiguous and num_groups > 4)
return {false, false};
return {
is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
// For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
@@ -96,9 +118,10 @@ struct SM90ArchSpec {
int smem_sfa_per_stage = block_m * static_cast<int>(sizeof(float));
int smem_sfb_per_stage = 0;
// TODO: figure out here
if (kernel_type == KernelType::Kernel1D1D)
smem_sfb_per_stage = align(block_n * 4, block_k);
if (kernel_type == KernelType::Kernel1D1D) {
// NOTES: `128` is for 2D TMA alignment requirement
smem_sfb_per_stage = align(block_n * 4, 128);
}
return {smem_sfa_per_stage, smem_sfb_per_stage};
}
@@ -109,13 +132,16 @@ struct SM90ArchSpec {
}
static int get_barrier_smem_size(const int& num_stages) {
// For 1D1D kernels, there is an extra barrier for accumulation
return (num_stages + 1) * 8 * 2;
return num_stages * 8 * 2;
}
static int get_tmem_ptr_smem_size() {
return 0;
}
static int get_tensormap_smem_size(const GemmType& gemm_type) {
return gemm_type == GemmType::KGroupedContiguous ? 4 * static_cast<int>(sizeof(CUtensorMap)) : 0;
}
};
} // namespace deep_gemm

View File

@@ -0,0 +1,12 @@
#pragma once
#include <optional>
#include <string>
namespace deep_gemm {
static std::string get_default_epilogue_type(const std::optional<std::string>& epilogue_type) {
return epilogue_type.value_or("EpilogueIdentity");
}
} // namespace deep_gemm

View File

@@ -4,6 +4,8 @@
#include <torch/python.h>
#include "../../utils/math.hpp"
#include "../heuristics/sm90.hpp"
#include "../../utils/system.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
@@ -51,7 +53,11 @@ static std::string to_string(const at::ScalarType& dtype) {
}
}
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype) {
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype,
const bool& allow_tf32) {
if (allow_tf32 and dtype == torch::kFloat)
return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32;
switch (dtype) {
case torch::kInt: return CU_TENSOR_MAP_DATA_TYPE_INT32;
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
@@ -61,9 +67,14 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
}
}
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode) {
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) {
if (base != 0) {
DG_HOST_ASSERT(base == 32 and mode == 128);
return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
}
switch (mode) {
case 0: return CU_TENSOR_MAP_SWIZZLE_NONE;
case 0:
case 16: return CU_TENSOR_MAP_SWIZZLE_NONE;
case 32: return CU_TENSOR_MAP_SWIZZLE_32B;
case 64: return CU_TENSOR_MAP_SWIZZLE_64B;
@@ -76,7 +87,8 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
int gmem_inner_dim, int gmem_outer_dim,
int smem_inner_dim, int smem_outer_dim,
const int& gmem_outer_stride,
const int& swizzle_mode) {
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
const auto& elem_size = static_cast<int>(t.element_size());
if (swizzle_mode != 0)
smem_inner_dim = swizzle_mode / elem_size;
@@ -87,14 +99,42 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
const cuuint64_t gmem_strides[1] = {static_cast<cuuint64_t>(gmem_outer_stride * elem_size), };
const cuuint32_t elem_strides[2] = {1, 1};
if (get_env<int>("DG_JIT_DEBUG")) {
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d, elem size: %d\n",
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d\n",
gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim,
gmem_outer_stride, swizzle_mode, elem_size);
gmem_outer_stride, swizzle_mode, swizzle_base, elem_size);
}
DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type()),
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode),
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
return tensor_map;
}
static CUtensorMap make_tma_3d_desc(const torch::Tensor& t,
const int& gmem_dim_0, const int& gmem_dim_1, const int& gmem_dim_2,
const int& smem_dim_0, const int& smem_dim_1, const int& smem_dim_2,
const int& gmem_stride_0, const int& gmem_stride_1,
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
const auto& elem_size = static_cast<int>(t.element_size());
if (swizzle_mode != 0)
DG_HOST_ASSERT(smem_dim_0 == swizzle_mode / elem_size);
CUtensorMap tensor_map;
const cuuint64_t gmem_dims[3] = {static_cast<cuuint64_t>(gmem_dim_0), static_cast<cuuint64_t>(gmem_dim_1), static_cast<cuuint64_t>(gmem_dim_2),};
const cuuint32_t smem_dims[3] = {static_cast<cuuint32_t>(smem_dim_0), static_cast<cuuint32_t>(smem_dim_1), static_cast<cuuint32_t>(smem_dim_2)};
const cuuint64_t gmem_strides[2] = {static_cast<cuuint64_t>(gmem_stride_0 * elem_size), static_cast<cuuint64_t>(gmem_stride_1 * elem_size)};
const cuuint32_t elem_strides[3] = {1, 1, 1};
if (get_env<int>("DG_JIT_DEBUG")) {
printf("Making 3D TMA desc: global memory: %d %d %d, shared memory: %d %d %d, outer stride: %d %d, swizzle: %d, elem size: %d\n",
gmem_dim_0, gmem_dim_1, gmem_dim_2, smem_dim_0, smem_dim_1, smem_dim_2,
gmem_stride_0, gmem_stride_1, swizzle_mode, elem_size);
}
DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
return tensor_map;
}
@@ -105,7 +145,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
const int& block_m, const int& block_k,
const int& outer_stride,
const int& num_groups,
const int& swizzle_mode) {
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
if (num_groups > 1)
DG_HOST_ASSERT(major == cute::UMMA::Major::K);
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
@@ -114,7 +155,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
gmem_inner_dim, gmem_outer_dim,
smem_inner_dim, smem_outer_dim,
outer_stride,
swizzle_mode);
swizzle_mode, swizzle_base,
allow_tf32);
}
static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
@@ -123,7 +165,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
const int& block_n, const int& block_k,
const int& outer_stride,
const int& num_groups,
const int& swizzle_mode) {
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n);
const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n);
@@ -132,7 +175,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
gmem_inner_dim, gmem_outer_dim * num_groups,
smem_inner_dim, smem_outer_dim,
outer_stride,
swizzle_mode);
swizzle_mode, swizzle_base,
allow_tf32);
}
static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
@@ -140,15 +184,16 @@ static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
const int& block_m, const int& block_n,
const int& outer_stride,
const int& num_groups,
const int& swizzle_mode) {
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
// Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode`
// bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
return make_tma_2d_desc(t,
shape_n, shape_m * num_groups,
block_n, block_m,
outer_stride,
swizzle_mode);
swizzle_mode, swizzle_base,
allow_tf32);
}
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
@@ -156,7 +201,8 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
int shape_mn, int shape_k,
const int& block_mn, const int& block_k,
const int& num_groups,
const int& swizzle_mode) {
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
DG_HOST_ASSERT(major == cute::UMMA::Major::MN);
// TODO: maybe swizzle SF as well
@@ -167,7 +213,8 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups,
block_mn, 1,
shape_mn,
swizzle_mode);
swizzle_mode, swizzle_base,
allow_tf32);
}
} // namespace deep_gemm

View File

@@ -42,7 +42,7 @@ static void __instantiate_kernel() {{
{}, {}, {},
{},
{}, {}, {},
{}, {},
{},
{}, {},
{}, {},
{},
@@ -56,7 +56,7 @@ static void __instantiate_kernel() {{
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.num_groups,
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.num_stages,
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms,
@@ -80,8 +80,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
// TODO: test other Ks
DG_HOST_ASSERT(k % 64 == 0);
const auto& aligned_k = align(k, 64);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Normal, KernelType::KernelNoSF,
m, n, k, 1, major_a, major_b,
@@ -122,7 +121,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
// Launch
const SM100BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = k,
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_config = config,

View File

@@ -0,0 +1,137 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
class SM100BmkBnkMnRuntime final: public LaunchRuntime<SM100BmkBnkMnRuntime> {
public:
struct Args {
int s, m, n, k;
int block_m, block_n, block_k;
int split_factor;
int swizzle_ab_mode, swizzle_cd_mode;
int num_stages;
int num_threads;
LaunchArgs launch_args;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_bmn_bnk_mn_gemm_impl<
{}, {}, {},
{}, {}, {},
{},
{}, {},
{}, {}
>);
}};
)",
args.m, args.n, args.k,
args.block_m, args.block_n, args.block_k,
args.split_factor,
args.swizzle_ab_mode, args.swizzle_cd_mode,
args.num_stages, args.num_threads);
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.s, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d));
}
};
static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a,
const torch::Tensor &b,
const torch::Tensor &d,
const int &s, const int &m, const int &n, const int &k) {
constexpr int block_m = 128;
constexpr int block_n = 128;
constexpr int block_k = 64;
constexpr int num_threads = 128;
DG_HOST_ASSERT(k % block_k == 0);
DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0);
DG_HOST_ASSERT(static_cast<int64_t>(s) * static_cast<int64_t>(std::max(m, n)) <= std::numeric_limits<int>::max());
const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast<int>(a.element_size()));
const int swizzle_cd_mode = get_swizzle_mode(block_n, static_cast<int>(d.element_size()));
// Get best config
const int num_sms = device_runtime->get_num_sms();
const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);
const int num_sk_blocks = s * (k / block_k);
const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1));
// Select best number of stages
// NOTES: we select 4 as start, as it is tested to be faster than values > 4
int num_stages = 4, smem_size = 0;
while (true) {
const int& smem_cd = block_m * swizzle_cd_mode * 2;
const int& smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t);
const int& smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t);
const int& smem_barrier = SM100ArchSpec::get_barrier_smem_size(num_stages);
const int& smem_tmem_ptr = SM100ArchSpec::get_tmem_ptr_smem_size();
smem_size = 0;
smem_size += smem_cd;
smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages;
smem_size += smem_barrier;
smem_size += smem_tmem_ptr;
if (smem_size <= SM100ArchSpec::smem_capacity)
break;
-- num_stages;
}
DG_HOST_ASSERT(num_stages > 0);
// Print configs
if (get_env("DG_JIT_DEBUG", 0)) {
printf("S: %d, M: %d, N: %d, K: %d -> "
"block M: %d, block N: %d, block K: %d, split-K factor: %d"
"stages: %d, shared memory: %d, swizzle AB: %d, swizzle CD: %d\n",
s, m, n, k, block_m, block_n, block_k, split_factor,
num_stages, smem_size, swizzle_ab_mode, swizzle_cd_mode);
}
const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);
const auto& tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode);
const SM100BmkBnkMnRuntime::Args& args = {
.s = s, .m = m, .n = n, .k = k,
.block_m = block_m, .block_n = block_n, .block_k = block_k,
.split_factor = split_factor,
.swizzle_ab_mode = swizzle_ab_mode,
.swizzle_cd_mode = swizzle_cd_mode,
.num_stages = num_stages,
.num_threads = num_threads,
.launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_threads, smem_size),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d
};
const auto& code = SM100BmkBnkMnRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code);
SM100BmkBnkMnRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -9,6 +9,8 @@
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "epilogue.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
@@ -18,6 +20,7 @@ public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
const std::optional<std::string>& epilogue_type;
GemmConfig gemm_config;
LaunchArgs launch_args;
@@ -44,11 +47,12 @@ static void __instantiate_kernel() {{
{}, {}, {},
{},
{}, {}, {},
{}, {},
{},
{}, {},
{}, {},
{},
{}, {}, {}
{}, {}, {},
{}
>);
}};
)",
@@ -57,11 +61,12 @@ static void __instantiate_kernel() {{
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.num_groups,
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.num_stages,
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms,
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype));
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype),
get_default_epilogue_type(args.epilogue_type));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -80,7 +85,8 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const std::string& compiled_dims,
const std::optional<std::string>& epilogue_type = std::nullopt) {
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Normal, KernelType::Kernel1D1D,
@@ -99,7 +105,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
@@ -129,6 +135,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.epilogue_type = epilogue_type,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -186,6 +193,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -244,6 +252,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -324,6 +333,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
.m = m, .n = n, .k = sum_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,

View File

@@ -18,6 +18,7 @@ public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
const std::optional<std::string>& epilogue_type;
GemmConfig gemm_config;
LaunchArgs launch_args;
@@ -46,7 +47,8 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{},
{}, {}
{}, {},
{}
>);
}};
)",
@@ -59,7 +61,8 @@ static void __instantiate_kernel() {{
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms,
to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype));
to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype),
get_default_epilogue_type(args.epilogue_type));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -78,7 +81,8 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const std::string& compiled_dims,
const std::optional<std::string>& epilogue_type = std::nullopt) {
DG_HOST_ASSERT(not c.has_value());
const auto& aligned_k = align(k, 128);
@@ -98,7 +102,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
@@ -111,6 +115,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.epilogue_type = epilogue_type,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -164,6 +169,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -218,6 +224,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const t
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,

View File

@@ -41,7 +41,7 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{}, {},
{}, {}
{}, {}, {}
>);
}};
)",
@@ -53,7 +53,8 @@ static void __instantiate_kernel() {{
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
to_string(args.gemm_config.cd_dtype));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -73,10 +74,10 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(not c.has_value());
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(k % 64 == 0);
const auto& aligned_k = align(k, 64);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Normal, KernelType::KernelNoSF,
m, n, k, 1, major_a, major_b,
@@ -102,7 +103,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
// Launch
const SM90BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = k,
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_config = config,

View File

@@ -0,0 +1,131 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
class SM90BmkBnkMnRuntime final: public LaunchRuntime<SM90BmkBnkMnRuntime> {
public:
struct Args {
int s, m, n, k;
int block_m, block_n, block_k;
int split_factor;
int num_stages;
int num_tma_threads, num_math_threads;
LaunchArgs launch_args;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
float* d;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_bmn_bnk_mn_gemm_impl<
{}, {}, {},
{}, {}, {},
{},
{},
{}, {}
>);
}};
)",
args.m, args.n, args.k,
args.block_m, args.block_n, args.block_k,
args.split_factor,
args.num_stages,
args.num_tma_threads, args.num_math_threads);
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.s, args.tensor_map_a, args.tensor_map_b, args.d));
}
};
static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a,
const torch::Tensor &b,
const torch::Tensor &d,
const int &s, const int &m, const int &n, const int &k) {
constexpr int block_m = 128;
constexpr int block_n = 128;
constexpr int block_k = 64;
constexpr int num_tma_threads = 128;
constexpr int num_math_threads = 256;
DG_HOST_ASSERT(k % block_k == 0);
DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0);
DG_HOST_ASSERT(static_cast<int64_t>(s) * static_cast<int64_t>(std::max(m, n)) <= std::numeric_limits<int>::max());
const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast<int>(a.element_size()));
DG_HOST_ASSERT(swizzle_ab_mode == 128);
// Get best config
const int num_sms = device_runtime->get_num_sms();
const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);
const int num_sk_blocks = s * (k / block_k);
const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1));
// Select best number of stages
int num_stages = 4, smem_size = 0;
while (true) {
const int& smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t);
const int& smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t);
const int& smem_barrier = SM90ArchSpec::get_barrier_smem_size(num_stages);
smem_size = 0;
smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages;
smem_size += smem_barrier;
if (smem_size <= SM90ArchSpec::smem_capacity)
break;
-- num_stages;
}
DG_HOST_ASSERT(num_stages > 0);
// Print configs
if (get_env("DG_JIT_DEBUG", 0)) {
printf("S: %d, M: %d, N: %d, K: %d -> "
"block M: %d, block N: %d, block K: %d, split-K factor: %d"
"stages: %d, shared memory: %d, swizzle AB: %d\n",
s, m, n, k, block_m, block_n, block_k, split_factor,
num_stages, smem_size, swizzle_ab_mode);
}
const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);
const SM90BmkBnkMnRuntime::Args& args = {
.s = s, .m = m, .n = n, .k = k,
.block_m = block_m, .block_n = block_n, .block_k = block_k,
.split_factor = split_factor,
.num_stages = num_stages,
.num_tma_threads = num_tma_threads,
.num_math_threads = num_math_threads,
.launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_tma_threads + num_math_threads, smem_size),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.d = d.data_ptr<float>()
};
const auto& code = SM90BmkBnkMnRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code);
SM90BmkBnkMnRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,214 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime<SM90FP8Gemm1D1DRuntime> {
public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
GemmConfig gemm_config;
LaunchArgs launch_args;
void *gmem_a_ptr;
void *gmem_b_ptr;
void *grouped_layout;
void *tensor_map_buffer;
CUtensorMap tensor_map_a_base;
CUtensorMap tensor_map_b_base;
CUtensorMap tensor_map_sfa;
CUtensorMap tensor_map_sfb;
CUtensorMap tensor_map_d;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d1d_impl<
{}, {}, {},
{},
{}, {}, {},
{},
{}, {},
{}, {},
{},
{}, {}
>);
}};
)",
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.num_groups,
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.gemm_config.num_stages,
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
to_string(args.gemm_config.cd_dtype));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.gmem_a_ptr, args.gmem_b_ptr,
args.grouped_layout,
args.tensor_map_buffer,
args.m, args.n, args.k,
args.tensor_map_a_base, args.tensor_map_b_base,
args.tensor_map_sfa, args.tensor_map_sfb,
args.tensor_map_d));
}
};
static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Normal, KernelType::Kernel1D1D,
m, n, k, 1, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k, k, 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k, k, 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, 1, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.block_n, config.block_k, 1, 0);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m, true),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
0);
// Launch
const SM90FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.gmem_a_ptr = nullptr,
.gmem_b_ptr = nullptr,
.grouped_layout = nullptr,
.tensor_map_buffer = nullptr,
.tensor_map_a_base = tensor_map_a,
.tensor_map_b_base = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_d = tensor_map_d,
};
const auto& code = SM90FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code);
SM90FP8Gemm1D1DRuntime::launch(runtime, args);
}
static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n,
const std::vector<int>& ks, const torch::Tensor& ks_tensor,
const torch::Tensor& tensor_map_buffer,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
// Get config using max K for better performance
const auto& num_groups = static_cast<int>(ks.size());
const auto& max_k = *std::max_element(ks.begin(), ks.end());
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
m, n, max_k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
int first_k = 0, sum_k = 0, sum_sf_k = 0;
for (int i = 0; i < num_groups; ++ i) {
if (first_k == 0 and ks[i] != 0)
first_k = ks[i];
sum_k += ks[i], sum_sf_k += ceil_div(ks[i], 128);
DG_HOST_ASSERT(ks[i] % 128 == 0);
}
const auto& tensor_map_a_base = make_tma_a_desc(major_a, a, m, first_k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k, first_k, 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b_base = make_tma_b_desc(major_b, b, n, first_k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k, first_k, 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128,
config.block_m, config.block_k, 1, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128,
config.block_n, config.block_k, 1, 0);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m, true),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), num_groups,
config.smem_config.swizzle_cd_mode);
// Launch
const SM90FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = sum_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.gmem_a_ptr = a.data_ptr(),
.gmem_b_ptr = b.data_ptr(),
.grouped_layout = ks_tensor.data_ptr(),
.tensor_map_buffer = tensor_map_buffer.data_ptr(),
.tensor_map_a_base = tensor_map_a_base,
.tensor_map_b_base = tensor_map_b_base,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_d = tensor_map_d,
};
const auto& code = SM90FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code);
SM90FP8Gemm1D1DRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -8,6 +8,8 @@
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../heuristics/sm90.hpp"
#include "epilogue.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
@@ -17,6 +19,7 @@ public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
const std::optional<std::string>& epilogue_type;
GemmConfig gemm_config;
LaunchArgs launch_args;
@@ -43,7 +46,7 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{}, {},
{}, {}
{}, {}, {}
>);
}};
)",
@@ -55,7 +58,8 @@ static void __instantiate_kernel() {{
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
get_default_epilogue_type(args.epilogue_type));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -74,7 +78,8 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const std::string& compiled_dims,
const std::optional<std::string>& epilogue_type = std::nullopt) {
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
@@ -98,7 +103,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
@@ -111,6 +116,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.epilogue_type = epilogue_type,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -170,6 +176,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -230,6 +237,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,

View File

@@ -0,0 +1,151 @@
#pragma once
#include <cublasLt.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
namespace deep_gemm {
static auto get_cublaslt_layout(const cudaDataType& type, const int& rows, const int& cols, const int& ld,
const std::optional<int>& batch_count = std::nullopt,
const std::optional<int>& batch_offset = std::nullopt) {
cublasLtMatrixLayout_t layout;
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&layout, type, rows, cols, ld));
if (batch_count.has_value()) {
DG_HOST_ASSERT(batch_offset.has_value());
const int64_t batch_offset_int64 = batch_offset.value();
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count.value(), sizeof(batch_count.value())));
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_offset_int64, sizeof(batch_offset_int64)));
}
return layout;
}
static void call_cublaslt_api(const cublasOperation_t& trans_a,
const cublasOperation_t& trans_b,
const cublasLtMatrixLayout_t& layout_a,
const cublasLtMatrixLayout_t& layout_b,
const cublasLtMatrixLayout_t& layout_d,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& d,
const bool& accumulate) {
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
cudaDataType_t scale_type = CUDA_R_32F;
const int& math_sms = device_runtime->get_num_sms();
bool fp8_fast_accumulate = false;
// Operation description
cublasLtMatmulDesc_t desc;
DG_CUBLASLT_CHECK(cublasLtMatmulDescCreate(&desc, compute_type, scale_type));
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a)));
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b)));
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms)));
if (a.scalar_type() == torch::kFloat8_e4m3fn)
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate)));
// Get cuBLASLt handle, workspace, and stream
const auto& handle = device_runtime->get_cublaslt_handle();
const auto& workspace = device_runtime->get_cublaslt_workspace();
const auto& workspace_bytes = workspace.nbytes();
const auto& stream = at::cuda::getCurrentCUDAStream();
// Algorithm selection
cublasLtMatmulPreference_t pref;
cublasLtMatmulHeuristicResult_t heuristic;
int num_heuristic_results = 0;
uint32_t reduction_scheme_mask = CUBLASLT_REDUCTION_SCHEME_NONE | CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE;
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceCreate(&pref));
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_bytes, sizeof(workspace_bytes)));
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
&reduction_scheme_mask, sizeof(reduction_scheme_mask)));
DG_CUBLASLT_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, layout_a, layout_b, layout_d, layout_d,
pref, 1, &heuristic, &num_heuristic_results));
DG_HOST_ASSERT(num_heuristic_results == 1 and "Unable to find any algorithm for the GEMM");
// Call: D = alpha * (A @ B) + beta * C
const float& alpha = 1.0, beta = accumulate ? 1.0 : 0.0;
DG_CUBLASLT_CHECK(cublasLtMatmul(handle, // Light handle
desc, // Operation description
&alpha, // Alpha
b.data_ptr(), layout_a, // A
a.data_ptr(), layout_b, // B
&beta, // Beta
d.data_ptr(), layout_d, // C
d.data_ptr(), layout_d, // D
&heuristic.algo, // Algorithm
workspace.data_ptr(), workspace_bytes, // Workspace
stream)); // Stream
// Free memory
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceDestroy(pref));
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_a));
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_b));
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_d));
DG_CUBLASLT_CHECK(cublasLtMatmulDescDestroy(desc));
}
static void cublaslt_gemm(const torch::Tensor& lhs, const torch::Tensor& rhs,
const std::optional<torch::Tensor>& acc,
const torch::Tensor& out,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major) {
const auto& trans_a = b_major == cute::UMMA::Major::K ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto& trans_b = a_major == cute::UMMA::Major::K ? CUBLAS_OP_N : CUBLAS_OP_T;
// Duplicate the accumulator if necessary
// TODO: remove this
if (acc.has_value()) {
if (acc->data_ptr() == out.data_ptr()) {
DG_HOST_ASSERT(acc->sizes() == out.sizes() and acc->strides() == out.strides());
} else {
out.copy_(acc.value());
}
}
// Matrix layouts
const auto& cuda_type_a = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type());
const auto& cuda_type_b = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type());
const auto& cuda_type_d = at::cuda::ScalarTypeToCudaDataType(out.scalar_type());
const auto& layout_a = b_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_a, k, n, rhs.stride(0))
: get_cublaslt_layout(cuda_type_a, n, k, rhs.stride(1));
const auto& layout_b = a_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_b, k, m, lhs.stride(0))
: get_cublaslt_layout(cuda_type_b, m, k, lhs.stride(1));
const auto& layout_d = get_cublaslt_layout(cuda_type_d, n, m, out.stride(0));
call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, acc.has_value());
}
static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
const int& b, const int& h, const int& r, const int& d) {
const auto& m = d, n = b, k = r;
const auto& trans_a = CUBLAS_OP_T;
const auto& trans_b = CUBLAS_OP_N;
// Matrix layouts
const auto& layout_a = get_cublaslt_layout(CUDA_R_16BF, k, m, rhs.stride(1), h, rhs.stride(0));
const auto& layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1));
const auto& layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1));
call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false);
}
static void cublaslt_bhd_hdr_bhr(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
const int& b, const int& h, const int& r, const int& d) {
const auto& m = r, n = b, k = d;
const auto& trans_a = CUBLAS_OP_N;
const auto& trans_b = CUBLAS_OP_N;
// Matrix layouts
const auto& layout_a = get_cublaslt_layout(CUDA_R_16BF, m, k, rhs.stride(1), h, rhs.stride(0));
const auto& layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1));
const auto& layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1));
call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false);
}
} // namespace deep_gemm

View File

@@ -1,6 +1,8 @@
#include <pybind11/pybind11.h>
#include <torch/python.h>
#include "apis/attention.hpp"
#include "apis/einsum.hpp"
#include "apis/gemm.hpp"
#include "apis/layout.hpp"
#include "apis/runtime.hpp"
@@ -13,6 +15,8 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "DeepGEMM C++ library";
deep_gemm::attention::register_apis(m);
deep_gemm::einsum::register_apis(m);
deep_gemm::gemm::register_apis(m);
deep_gemm::layout::register_apis(m);
deep_gemm::runtime::register_apis(m);

View File

@@ -1,5 +1,6 @@
#pragma once
#include <cublasLt.h>
#include <exception>
#include <string>
#include <sstream>
@@ -72,4 +73,16 @@ do { \
} while (0)
#endif
#ifndef DG_CUBLASLT_CHECK
#define DG_CUBLASLT_CHECK(cmd) \
do { \
const auto& e = (cmd); \
if (e != CUBLAS_STATUS_SUCCESS) { \
std::ostringstream ss; \
ss << static_cast<int>(e) << " (" << cublasGetStatusString(e) << ")"; \
throw DGException("cuBLASLt", __FILE__, __LINE__, ss.str()); \
} \
} while (0)
#endif
} // namespace deep_gemm

View File

@@ -1,11 +1,15 @@
#pragma once
#include <array>
#include <filesystem>
#include <functional>
#include <random>
#include <string>
#include <memory>
#include <unistd.h>
#include "exception.hpp"
#include "format.hpp"
namespace deep_gemm {
@@ -65,7 +69,10 @@ static std::filesystem::path make_dirs(const std::filesystem::path& path) {
// OK if existed
std::error_code capture;
const bool& created = std::filesystem::create_directories(path, capture);
DG_HOST_ASSERT(created or capture.value() == 0);
if (not (created or capture.value() == 0)) {
DG_HOST_UNREACHABLE(fmt::format("Failed to make directory: {}, created: {}, value: {}",
path.c_str(), created, capture.value()));
}
if (created and get_env<int>("DG_JIT_DEBUG"))
printf("Create directory: %s\n", path.c_str());
return path;

View File

@@ -25,15 +25,22 @@ from deep_gemm_cpp import (
# FP8 GEMMs
fp8_gemm_nt, fp8_gemm_nn,
fp8_gemm_tn, fp8_gemm_tt,
fp8_gemm_nt_skip_head_mid,
m_grouped_fp8_gemm_nt_contiguous,
m_grouped_fp8_gemm_nn_contiguous,
m_grouped_fp8_gemm_nt_masked,
k_grouped_fp8_gemm_nt_contiguous,
k_grouped_fp8_gemm_tn_contiguous,
# BF16 GEMMs
bf16_gemm_nt, bf16_gemm_nn,
bf16_gemm_tn, bf16_gemm_tt,
m_grouped_bf16_gemm_nt_contiguous,
m_grouped_bf16_gemm_nt_masked,
# cuBLASLt GEMMs
cublaslt_gemm_nt, cublaslt_gemm_nn,
cublaslt_gemm_tn, cublaslt_gemm_tt,
# Einsum kernels
einsum,
# Layout kernels
transform_sf_into_required_layout
)

View File

@@ -0,0 +1,27 @@
#pragma once
#include <deep_gemm/common/types.hpp>
#include <deep_gemm/common/utils.cuh>
namespace deep_gemm {
struct EpilogueIdentity {
template <uint32_t STORE_BLOCK_N>
__device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
return n_idx;
}
};
template <uint32_t kLeft, uint32_t kMid, uint32_t kRight>
struct EpilogueHeadSplits: EpilogueIdentity {
template <uint32_t STORE_BLOCK_N>
__device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0
and kRight % STORE_BLOCK_N == 0, "Invalid head splits config");
return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid;
}
};
#pragma clang diagnostic pop
} // namespace deep_gemm

View File

@@ -0,0 +1,44 @@
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda/std/cstdint>
#include <cuda/std/utility>
#include <deep_gemm/common/utils.cuh>
// Operation functors
template <typename T> struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } };
template <typename T> struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } };
template <typename T> struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } };
template <typename T> struct ReduceAnd { __device__ T operator()(T a, T b) const { return a & b; } };
template <typename T> struct ReduceOr { __device__ T operator()(T a, T b) const { return a | b; } };
// Unified reduction function
template <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
__forceinline__ __device__ T warp_reduce(T value, Op op) {
DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
"Invalid number of lanes");
constexpr uint32_t mask = 0xffffffff;
if constexpr (kIntergroupReduce) {
if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1));
if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2));
if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4));
if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8));
if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
} else {
if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8));
if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4));
if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2));
if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1));
}
return value;
}
// Convenience aliases
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_sum(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
}

View File

@@ -22,7 +22,6 @@ static constexpr uint32_t get_num_1d_blocks_per_group() {
if (usage < min_usage)
min_usage = usage, num_best_blocks = candidate;
}
return num_best_blocks;
}
@@ -33,6 +32,7 @@ template <GemmType kGemmType,
uint32_t kNumGroups,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
uint32_t SF_K_ALIGNMENT = 512u, // for k-grouped GEMM only: 128 (SM90 float SF) or 512 (SM100 UE8M0 SF)
uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<kGemmType, BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
struct Scheduler {
int current_iter = -1;
@@ -48,30 +48,40 @@ struct Scheduler {
// For grouped GEMM
int* grouped_layout;
uint32_t current_group_idx;
uint32_t current_group_idx = 0;
// Only used for masked layout
uint32_t current_m_cumsum;
uint32_t current_m_cumsum = 0;
// Only used for k-grouped layout
uint32_t current_shape_k, current_num_valid_groups, current_k_cumsum, current_sf_k_cumsum;
uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0;
uint32_t next_group_idx, next_shape_k;
// Only used for k-grouped gemm
__device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const {
for (; group_idx < kNumGroups; ++ group_idx) {
shape_k = __ldg(grouped_layout + group_idx);
if (shape_k > 0)
break;
}
}
// ReSharper disable once CppPossiblyUninitializedMember
__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n,
__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k,
int* grouped_layout = nullptr) {
num_m_blocks = ceil_div(shape_m, BLOCK_M);
num_n_blocks = ceil_div(shape_n, BLOCK_N);
current_shape_k = shape_k;
if constexpr (kGemmType == GemmType::Normal) {
num_blocks = num_m_blocks * num_n_blocks;
} else if (kGemmType == GemmType::MGroupedContiguous) {
num_blocks = num_m_blocks * num_n_blocks;
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::MGroupedMasked) {
current_group_idx = current_m_cumsum = 0;
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::KGroupedContiguous) {
current_group_idx = current_num_valid_groups = 0;
current_k_cumsum = current_sf_k_cumsum = 0;
current_shape_k = __ldg(grouped_layout + current_group_idx);
this->grouped_layout = grouped_layout;
get_next_k_group(current_group_idx, current_shape_k);
next_group_idx = current_group_idx + 1;
get_next_k_group(next_group_idx, next_shape_k);
}
}
@@ -165,17 +175,17 @@ struct Scheduler {
return false;
// Within current group
if (current_shape_k > 0 and next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks)
if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks)
break;
// Move to check the next group
if (current_shape_k > 0) {
current_k_cumsum += current_shape_k;
current_sf_k_cumsum += ceil_div(current_shape_k, 512u);
current_num_valid_groups ++;
}
if ((++ current_group_idx) != kNumGroups)
current_shape_k = __ldg(grouped_layout + current_group_idx);
current_k_cumsum += current_shape_k;
current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT);
current_num_valid_groups ++;
current_group_idx = next_group_idx ++;
current_shape_k = next_shape_k;
get_next_k_group(next_group_idx, next_shape_k);
}
get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx);
@@ -197,7 +207,7 @@ struct Scheduler {
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
if (num_blocks_in_group == 1)
return false;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked) {
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::KGroupedContiguous) {
return true;
} else {
DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type");

View File

@@ -79,12 +79,24 @@ void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_p
desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);
}
__device__ __forceinline__
static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) {
return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16;
}
// ReSharper disable once CppNotAllPathsReturnValue
template <uint32_t kSwizzleMode>
template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, bool kUseBase32, typename dtype_t>
constexpr static cute::UMMA::LayoutType to_umma_layout_type() {
DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
kSwizzleMode == 32 or kSwizzleMode == 64 or
kSwizzleMode == 128, "Invalid swizzling mode");
// A special case
if constexpr ((cute::is_same_v<dtype_t, float> and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) {
DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base");
return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B;
}
// Normal cases
if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE;
if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE;
if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B;
@@ -104,10 +116,12 @@ uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, cons
return base + (((offset + k_idx * get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
}
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, bool kUseBase32 = false, typename dtype_t>
__device__ __forceinline__
cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
const uint32_t stride_k = get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
const auto& layout_type = to_umma_layout_type<kMajorMode, kSwizzleMode, kUseBase32, dtype_t>();
const auto& num_non_contiguous = 128 / get_atom_base(layout_type);
if constexpr (kMajorMode == cute::UMMA::Major::K) {
// NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128
DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
@@ -115,9 +129,9 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id
// Atom size: 8 x `kSwizzleMode` (in bytes, on K)
// {SBO, LBO} means the byte stride between atoms on {MN, K}
// NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
const uint32_t stride_byte_offset = 8 * BLOCK_K * sizeof(dtype_t);
const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
const uint32_t leading_byte_offset = 0;
return make_smem_desc(to_umma_layout_type<kSwizzleMode>(),
return make_smem_desc(layout_type,
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
stride_byte_offset, leading_byte_offset);
} else {
@@ -132,11 +146,11 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id
// NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving
// {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
// {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
uint32_t stride_byte_offset = 8 * BLOCK_MN_ATOM * sizeof(dtype_t);
uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
if constexpr (kSwizzleMode == 16)
swap(stride_byte_offset, leading_byte_offset);
return make_smem_desc(to_umma_layout_type<kSwizzleMode>(),
return make_smem_desc(layout_type,
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
stride_byte_offset, leading_byte_offset);
}
@@ -166,4 +180,81 @@ __device__ __forceinline__ void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
// UMMA versions with relaxed assertions
struct SM100_MMA_F16BF16_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
struct SM100_MMA_F16BF16_2x1SM_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t"
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
struct SM100_MMA_MXF8F6F4_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc,
uint32_t const& tmem_sfa,
uint32_t const& tmem_sfb) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
"r"(tmem_sfa), "r"(tmem_sfb));
}
};
struct SM100_MMA_MXF8F6F4_2x1SM_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc,
uint32_t const& tmem_sfa,
uint32_t const& tmem_sfb) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
"r"(tmem_sfa), "r"(tmem_sfb));
}
};
} // namespace `deep_gemm::sm100`

View File

@@ -1,8 +1,12 @@
#pragma once
#include <cute/arch/copy_sm90_tma.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
#include <deep_gemm/common/utils.cuh>
namespace deep_gemm::sm90 {
template <int N_, typename MMA>
@@ -29,6 +33,7 @@ struct FP8MMASelector {
static constexpr auto select_mma() {
using namespace cute::SM90::GMMA;
if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN();
@@ -93,6 +98,7 @@ struct BF16MMASelector {
static constexpr auto select_mma() {
using namespace cute::SM90::GMMA;
if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS<Major::K, Major::K>();
if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS<Major::K, Major::K>();
if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS<Major::K, Major::K>();
if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS<Major::K, Major::K>();
@@ -144,6 +150,24 @@ struct SM90_U32x2_STSM_N {
}
};
struct SM90_U32x2_LDSM_N {
__device__ __forceinline__ static void
copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) {
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst_0), "=r"(dst_1)
: "l"(smem_src));
}
};
struct SM90_U32x4_LDSM_N {
__device__ __forceinline__ static void
copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3)
: "l"(smem_src));
}
};
__forceinline__ __device__ void warpgroup_arrive() {
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}
@@ -223,4 +247,37 @@ tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
}
}
__device__ __forceinline__ void
tma_3d_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& crd_2) {
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
cute::SM90_TMA_LOAD_3D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1, crd_2);
}
// Tensormap related
__device__ __forceinline__ void tensor_map_release_cta() {
asm volatile ("fence.proxy.tensormap::generic.release.cta;");
}
__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) {
auto gmem_int_desc = reinterpret_cast<uint64_t>(gmem_desc_ptr);
asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory");
}
__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) {
auto smem_int_desc = static_cast<uint32_t>(__cvta_generic_to_shared(smem_desc));
const auto new_int64_addr = reinterpret_cast<uint64_t>(new_addr);
asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr));
}
__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) {
auto smem_int_desc = __cvta_generic_to_shared(smem_desc);
asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim));
#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 5)))
asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride));
#else
DG_STATIC_ASSERT(false, "Invalid CUDA version")
#endif
}
} // namespace `deep_gemm::sm90`

View File

@@ -104,6 +104,12 @@ __device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) {
return ret;
}
__device__ __forceinline__ float2 ld_shared(const float2* ptr) {
float2 ret;
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
return ret;
}
__device__ __forceinline__ float4 ld_shared(const float4* ptr) {
float4 ret;
asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(ptr));
@@ -126,10 +132,18 @@ __device__ __forceinline__ void st_shared(const float* ptr, float val) {
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}
__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y));
}
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
}
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) {
asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(ptr), "r"(x), "r"(y));
}
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(ptr), "r"(x), "r"(y), "r"(z), "r"(w));
}

View File

@@ -17,7 +17,7 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumStages,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
@@ -84,8 +84,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == 0) {
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
@@ -93,35 +92,31 @@ sm100_bf16_gemm_impl(int* grouped_layout,
cute::prefetch_tma_descriptor(&tensor_map_c);
}
// Data on shared memory (layout as ordered below)
cd_dtype_t* smem_cd[kNumTMAStoreStages];
cutlass::bfloat16_t* smem_a[kNumStages];
cutlass::bfloat16_t* smem_b[kNumStages];
// Fill D/A/B pointers
#pragma unroll
for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i)
smem_cd[i] = reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
}
// D/A/B shared memory
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
});
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2;
// Fill the tensor memory pointer
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (threadIdx.x == 0) {
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Arrive only at the leader CTA
@@ -136,11 +131,12 @@ sm100_bf16_gemm_impl(int* grouped_layout,
// Arrive only at the leader CTA
tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads);
}
if constexpr (kTensorCoreUtilControl < 100)
tensor_core_full_barrier->init(1);
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
cutlass::arch::fence_barrier_init();
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
} else if (warp_idx == 2) {
// Allocate tensor memory
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
@@ -148,100 +144,69 @@ sm100_bf16_gemm_impl(int* grouped_layout,
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
uint32_t phase = 0;
auto launch_k_iterations = [&](const auto& func) {
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K);
const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages;
// Pipeline and TMA phases
uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0;
auto advance_pipeline = [&](uint32_t& k_block_idx) {
++ k_block_idx;
// TODO: refactor here
if (num_last_stages == 0) {
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1)
func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages);
} else {
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1)
func(k_iter, DivisibleK{}, false, num_last_stages);
func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1;
}
};
auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) {
DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2,
"Too many epilogue stages, please modify the Python heuristic as well");
accum_stage_idx == 0 ? func(0) : func(1);
// Flip phases only if reach the next first stage
stage_idx = (stage_idx + 1) % kNumStages;
phase ^= stage_idx == 0;
};
// Dispatch warps into different roles
if (warp_idx == 0) {
if (warp_idx == 0 and cute::elect_one_sync()) {
// TMA load warp
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
#pragma unroll
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
empty_barriers[stage_idx]->wait(phase ^ 1);
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait(phase ^ 1);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-majored
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-majored
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_block_idx = k_iter * kNumStages + s;
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// Add 2 CTA offsets
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Issue TMAs
if (cute::elect_one_sync()) {
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx);
}
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
if (is_leader_cta and cute::elect_one_sync())
full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
if (not is_leader_cta and cute::elect_one_sync())
full_barriers[s]->arrive(0u);
// Add 2 CTA offsets
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait(phase ^ 1);
if (is_leader_cta and cute::elect_one_sync())
full_barriers[s]->arrive();
if (not is_leader_cta and cute::elect_one_sync())
full_barriers[s]->arrive(0u);
// Issue TMAs
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx);
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
if (is_leader_cta) {
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
} else {
full_barriers[stage_idx]->arrive(0u);
}
});
}
}
} else if (warp_idx == 1 and is_leader_cta) {
// MMA issue warp
@@ -268,88 +233,89 @@ sm100_bf16_gemm_impl(int* grouped_layout,
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
// Wait tensor memory empty barrier arrival
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
// Wait tensor memory empty barrier arrival
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
tcgen05_after_thread_sync();
// UMMA and empty barrier arrival alias
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
};
// Launch MMAs
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA arrival
full_barriers[stage_idx]->wait(phase);
tcgen05_after_thread_sync();
// Empty barrier arrival
auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) {
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[s]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
};
// Launch MMAs
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
// Issue UMMA in the leader CTA
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_F16BF16_SS, SM100_MMA_F16BF16_2x1SM_SS>;
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait TMA arrival
full_barriers[s]->wait(phase);
tcgen05_after_thread_sync();
// Let tensor cores relax for lower possibility of frequency drop
DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control");
if constexpr (kTensorCoreUtilControl < 100) {
constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull;
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
const auto& start_clock = clock64();
if (cute::elect_one_sync())
while (clock64() - start_clock < kNumDummyCycles) {}
__syncwarp();
}
// Issue UMMA in the leader CTA
using cute_mma_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_MMA_F16BF16_SS <cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>,
cute::SM100_MMA_F16BF16_2x1SM_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>>;
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s);
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s);
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
cute_mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_iter > 0 or s > 0 or k > 0,
runtime_instr_desc);
}
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_block_idx > 0 or k > 0,
runtime_instr_desc);
}
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1);
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait(phase);
empty_barrier_arrive(s, false);
}
});
});
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
// Let tensor cores relax for lower possibility of frequency drop
DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control");
if constexpr (kTensorCoreUtilControl < 100) {
// For utilization control
umma_arrive(reinterpret_cast<uint64_t*>(tensor_core_full_barrier));
// Wait for last UMMA to be done
tensor_core_full_barrier->wait(tensor_core_phase);
tensor_core_phase ^= 1;
// Sleep for certain cycles
constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull;
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
const auto& start_clock = clock64();
if (cute::elect_one_sync())
while (clock64() - start_clock < kNumDummyCycles) {}
__syncwarp();
}
}
}
// To safely deconstruct barriers, we need another round of waits
const auto& iter_idx = scheduler.current_iter - 1;
if (kNumMulticast > 1 and iter_idx >= 0) {
const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
}
} else if (warp_idx >= kNumNonEpilogueThreads / 32) {
// Epilogue warp groups
const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads;
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
@@ -363,129 +329,114 @@ sm100_bf16_gemm_impl(int* grouped_layout,
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Share store pipeline between blocks
uint32_t tma_stage_idx = 0;
auto advance_store_pipeline = [&]() {
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
};
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
// Flush TMA stores
// NOTES: for the first store, we have to flush all previous TMA,
// as we don't share pipeline stages between two blocks
if (epilogue_thread_idx == 0)
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
tcgen05_after_thread_sync();
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Iterate over M waves
// Iterate over M waves
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
// Wait shared memory to be released
if (epilogue_warp_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0);
// The pipeline stage
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Store into shared memory
#pragma unroll
for (uint32_t s = 0; s < kNumStores; ++ s) {
// Wait shared memory to be released
const uint32_t iter_idx = w * kNumStores + s;
if (iter_idx >= kNumTMAStoreStages) {
if (epilogue_thread_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
}
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// The pipeline stage
const auto tma_stage_idx = iter_idx % kNumTMAStoreStages;
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// Source and destination memory address
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Source and destination memory address
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
} else {
// For BF16 output, read, cast and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr,
cast_into_bf16_and_pack(values[0], values[1]),
cast_into_bf16_and_pack(values[2], values[3]),
cast_into_bf16_and_pack(values[4], values[5]),
cast_into_bf16_and_pack(values[6], values[7]));
}
}
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
tcgen05_before_thread_sync();
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
}
__syncwarp();
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
if (epilogue_thread_idx == 0) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
cute::tma_store_arrive();
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
} else {
// For BF16 output, read, cast and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr,
cast_into_bf16_and_pack(values[0], values[1]),
cast_into_bf16_and_pack(values[2], values[3]),
cast_into_bf16_and_pack(values[4], values[5]),
cast_into_bf16_and_pack(values[6], values[7]));
}
}
}
});
}
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
if (epilogue_thread_idx == 0)
cute::tma_store_wait<0>();
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
tcgen05_before_thread_sync();
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
}
__syncwarp();
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0);
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
cute::tma_store_arrive();
}
}
}
}
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
if (epilogue_warp_idx == 1)
Allocator().free(0, kNumTmemCols);
}
// To safely deconstruct all barriers, we need a cluster sync
// TODO: optimize it by another round of barrier waits
if constexpr (kNumMulticast > 1)
cute::cluster_sync();
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");

View File

@@ -0,0 +1,265 @@
#pragma once
#include <cute/arch/cluster_sm90.hpp>
#include <cute/util/type_traits.hpp>
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm100;
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kSplitFactor,
uint32_t kSwizzleABMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages, uint32_t kNumThreads>
__global__ void __launch_bounds__(kNumThreads, 1)
sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Configs
constexpr uint32_t LAYOUT_AD_M = 128;
constexpr uint32_t kNumTMAStoreStages = 2;
// Utils
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size");
DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode");
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// Shared memory sizes
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
// Real tensor memory size and offsets
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_N>();
// Fill D/A/B
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE));
});
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2);
// Fill the tensor memory pointer
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 2 + 1);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(1);
}
tmem_full_barrier->init(1);
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
} else if (warp_idx == 2) {
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
__syncthreads();
// Block indices
const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
if (warp_idx == 0) {
// TMA load warp
for (uint32_t s = 0; s < num_total_stages; ++ s) {
const auto& stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
uint32_t m_idx = BLOCK_M * m_block_idx;
uint32_t n_idx = BLOCK_N * n_block_idx;
uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
uint32_t k_idx = sk_idx % SHAPE_K;
uint32_t s_idx = sk_idx / SHAPE_K;
// Issue TMAs
if (cute::elect_one_sync()) {
tma_copy<BLOCK_K, BLOCK_M, kSwizzleABMode, 1>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M);
tma_copy<BLOCK_K, BLOCK_N, kSwizzleABMode, 1>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N);
}
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
if (cute::elect_one_sync())
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
}
} else if (warp_idx == 1) {
// MMA issue warp
// NOTES: only the leader CTA will do this
// Make instruction descriptor
constexpr uint32_t UMMA_M = LAYOUT_AD_M;
constexpr uint32_t UMMA_N = BLOCK_N;
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t);
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
auto a_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleABMode>(smem_a[0], 0, 0);
auto b_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleABMode>(smem_b[0], 0, 0);
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
// Checks for MMA instructions
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
"Invalid MMA instruction shape");
// Wait tensor memory empty barrier arrival
tcgen05_after_thread_sync();
// Launch MMAs
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
tcgen05_after_thread_sync();
// Issue UMMA in the leader CTA
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, stage_idx);
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, stage_idx);
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
a_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_M, kSwizzleABMode, cutlass::bfloat16_t>(a_desc_base_lo, 0, k * UMMA_K);
b_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_N, kSwizzleABMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc);
}
}
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
}
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barrier));
}
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
// i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`.
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
if (warp_idx == 2)
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
// TMA checks
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float);
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(float);
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Wait UMMA arrival
tmem_full_barrier->wait(0);
tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
#pragma unroll
for (uint32_t s = 0; s < kNumStores; ++ s) {
// Wait shared memory to be released
if (s >= kNumTMAStoreStages) {
if (warp_idx == 0 and cute::elect_one_sync())
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier(kNumThreads).sync();
}
// The pipeline stage
const auto tma_stage_idx = s % kNumTMAStoreStages;
const auto m_idx = m_block_idx * BLOCK_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Source and destination memory address
uint32_t tmem_addr = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
}
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumThreads).sync();
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::SM90_TMA_REDUCE_ADD_2D::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
cute::tma_store_arrive();
}
}
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is doing TMA stores
if (warp_idx == 1)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");
#endif
}
}

View File

@@ -4,6 +4,7 @@
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/epilogue_utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
@@ -17,11 +18,12 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumStages,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t>
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
typename epilogue_type_t>
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
@@ -96,8 +98,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == 0) {
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
@@ -107,30 +108,25 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
cute::prefetch_tma_descriptor(&tensor_map_c);
}
// Data on shared memory (layout as ordered below)
cd_dtype_t* smem_cd[kNumTMAStoreStages];
cutlass::float_e4m3_t* smem_a[kNumStages];
cutlass::float_e4m3_t* smem_b[kNumStages];
uint32_t* smem_sfa[kNumStages];
uint32_t* smem_sfb[kNumStages];
// D/A/B shared memory
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
});
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
// Fill D/A/B pointers
#pragma unroll
for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i)
smem_cd[i] = reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
}
// Fill SFA/SFB
// SFA/SFB shared memory
auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
smem_sfa[i] = reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
smem_sfb[i] = reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
}
auto smem_sfa = PatternVisitor([=](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
});
auto smem_sfb = PatternVisitor([=](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer +
@@ -148,7 +144,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (threadIdx.x == 0) {
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Arrive at all CTAs
@@ -166,9 +162,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
cutlass::arch::fence_barrier_init();
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
} else if (warp_idx == 2) {
// Allocate tensor memory
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
@@ -176,108 +171,75 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
uint32_t phase = 0;
auto launch_k_iterations = [&](const auto& func) {
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K);
const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages;
// Pipeline and TMA phases
uint32_t stage_idx = 0, phase = 0;
auto advance_pipeline = [&](uint32_t& k_block_idx) {
++ k_block_idx;
// TODO: refactor here
if (num_last_stages == 0) {
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1)
func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages);
} else {
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1)
func(k_iter, DivisibleK{}, false, num_last_stages);
func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1;
}
};
auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) {
DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2,
"Too many epilogue stages, please modify the Python heuristic as well");
accum_stage_idx == 0 ? func(0) : func(1);
// Flip phases only if reach the next first stage
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
phase ^= stage_idx == 0;
};
// Dispatch warps into different roles
if (warp_idx == 0) {
if (warp_idx == 0 and cute::elect_one_sync()) {
// TMA load warp
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
empty_barriers[stage_idx]->wait(phase ^ 1);
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait(phase ^ 1);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-majored
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-majored
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_block_idx = k_iter * kNumStages + s;
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// Add 2 CTA offsets
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Issue TMAs
if (cute::elect_one_sync()) {
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx);
}
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
// Issue SFA and SFB TMAs at certain stages
// No swizzling, so one TMA for one SF is enough
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
tma_copy<BLOCK_M, 1, 0, 1>(&tensor_map_sfa, full_barriers[s], smem_sfa[s], m_block_idx * BLOCK_M,
scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad)));
tma_copy<BLOCK_N, 1, 0, 1>(&tensor_map_sfb, full_barriers[s], smem_sfb[s], n_block_idx * BLOCK_N,
scheduler.template get_global_idx<true, KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx));
num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t);
}
// Arrive at full barriers
if (cute::elect_one_sync())
full_barriers[s]->arrive_and_expect_tx(num_arrival_bytes);
// Add 2 CTA offsets
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait(phase ^ 1);
if (cute::elect_one_sync())
full_barriers[s]->arrive();
// Issue TMAs
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx);
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
// Issue SFA and SFB TMAs at certain stages
// No swizzling, so one TMA for one SF is enough
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0) {
tma_copy<BLOCK_M, 1, 0, 1>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M,
scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad)));
tma_copy<BLOCK_N, 1, 0, 1>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N,
scheduler.template get_global_idx<true, KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx));
num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t);
}
});
// Arrive at full barriers
full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes);
}
}
} else if (warp_idx == 1 and is_leader_cta) {
// MMA issue warp
@@ -307,101 +269,93 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
// Wait tensor memory empty barrier arrival
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
// Wait tensor memory empty barrier arrival
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
tcgen05_after_thread_sync();
// Empty barrier arrival
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
};
// Launch MMAs
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA and SF-transpose arrival
with_sf_full_barriers[stage_idx]->wait(phase);
tcgen05_after_thread_sync();
// Empty barrier arrival
auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) {
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[s]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
};
// Launch MMAs
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
// Do SF copy at certain stages
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
// SFA and SFB copy
// TODO: process shared memory descriptor by addition
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait TMA and SF-transpose arrival
with_sf_full_barriers[s]->wait(phase);
tcgen05_after_thread_sync();
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
}
}
__syncwarp();
// Do SF copy at certain stages
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
// SFA and SFB copy
// TODO: process shared memory descriptor by addition
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[s] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[s] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
}
}
__syncwarp();
// Issue UMMA in the leader CTA
using cute_mma_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_MMA_MXF8F6F4_SS <cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
cutlass::float_ue8m0_t, UMMA_M, UMMA_N, kMajorA, kMajorB>,
cute::SM100_MMA_MXF8F6F4_2x1SM_SS<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
cutlass::float_ue8m0_t, UMMA_M, UMMA_N, kMajorA, kMajorB>>;
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s);
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s);
// Issue UMMA in the leader CTA
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_MXF8F6F4_SS, SM100_MMA_MXF8F6F4_2x1SM_SS>;
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t>(b_desc_base_lo, 0, k * UMMA_K);
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t>(b_desc_base_lo, 0, k * UMMA_K);
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
cute_mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_iter > 0 or s > 0 or k > 0,
runtime_instr_desc,
kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32),
kTmemStartColOfSFB);
}
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_block_idx > 0 or k > 0,
runtime_instr_desc,
kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32),
kTmemStartColOfSFB);
}
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1);
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
with_sf_full_barriers[s]->wait(phase);
empty_barrier_arrive(s, false);
}
});
});
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
}
}
// To safely deconstruct barriers, we need another round of waits
const auto& iter_idx = scheduler.current_iter - 1;
if (kNumMulticast > 1 and iter_idx >= 0) {
const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
}
} else if (warp_idx == 2) {
// UTCCP transposer
@@ -418,43 +372,30 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
};
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA arrival
full_barriers[stage_idx]->wait(phase);
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait TMA arrival
full_barriers[s]->wait(phase);
// Transpose for UTCCP at certain stages
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfa[s] + i * kNumUTCCPAlignedElems);
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfb[s] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
// Arrive
with_sf_full_barriers[s]->arrive(0u);
// Transpose for UTCCP at certain stages
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait(phase);
with_sf_full_barriers[s]->arrive(0u);
}
});
// Arrive
with_sf_full_barriers[stage_idx]->arrive(0u);
}
}
} else if (warp_idx >= kNumNonEpilogueThreads / 32) {
// Epilogue warp groups
const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads;
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
@@ -468,129 +409,113 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Share store pipeline between blocks
uint32_t tma_stage_idx = 0;
auto advance_store_pipeline = [&]() {
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
};
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
// Flush TMA stores
// NOTES: for the first store, we have to flush all previous TMA,
// as we don't share pipeline stages between two blocks
if (epilogue_thread_idx == 0)
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
tcgen05_after_thread_sync();
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Iterate over M waves
// Iterate over M waves
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
// Wait shared memory to be released
if (epilogue_warp_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0);
// The pipeline stage
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(n_block_idx * BLOCK_N + s * STORE_BLOCK_N);
// Store into shared memory
#pragma unroll
for (uint32_t s = 0; s < kNumStores; ++ s) {
// Wait shared memory to be released
const uint32_t iter_idx = w * kNumStores + s;
if (iter_idx >= kNumTMAStoreStages) {
if (epilogue_thread_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
}
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// The pipeline stage
const auto tma_stage_idx = iter_idx % kNumTMAStoreStages;
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// Source and destination memory address
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Source and destination memory address
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
} else {
// For BF16 output, read, cast and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr,
cast_into_bf16_and_pack(values[0], values[1]),
cast_into_bf16_and_pack(values[2], values[3]),
cast_into_bf16_and_pack(values[4], values[5]),
cast_into_bf16_and_pack(values[6], values[7]));
}
}
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
tcgen05_before_thread_sync();
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
}
__syncwarp();
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
if (epilogue_thread_idx == 0) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
cute::tma_store_arrive();
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
} else {
// For BF16 output, read, cast and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr,
cast_into_bf16_and_pack(values[0], values[1]),
cast_into_bf16_and_pack(values[2], values[3]),
cast_into_bf16_and_pack(values[4], values[5]),
cast_into_bf16_and_pack(values[6], values[7]));
}
}
}
});
}
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
if (epilogue_thread_idx == 0)
cute::tma_store_wait<0>();
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
tcgen05_before_thread_sync();
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
}
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0);
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
cute::tma_store_arrive();
}
}
}
}
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
if (epilogue_warp_idx == 1)
Allocator().free(0, kNumTmemCols);
}
// To safely deconstruct all barriers, we need a cluster sync
// TODO: optimize it by another round of barrier waits
if constexpr (kNumMulticast > 1)
cute::cluster_sync();
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");

View File

@@ -5,6 +5,7 @@
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <deep_gemm/common/epilogue_utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
@@ -22,7 +23,8 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
GemmType kGemmType, typename cd_dtype_t>
GemmType kGemmType, typename cd_dtype_t,
typename epilogue_type_t>
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
@@ -88,8 +90,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == 0) {
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
@@ -133,7 +134,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (threadIdx.x == 0) {
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Arrive at all CTAs
@@ -149,9 +150,8 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
cutlass::arch::fence_barrier_init();
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
} else if (warp_idx == 2) {
// Allocate tensor memory
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
@@ -174,7 +174,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
// Register configurations
constexpr uint32_t kNumNonEpilogueRegisters = 64;
@@ -435,7 +435,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// as we don't share pipeline stages between two blocks
if (epilogue_thread_idx_in_warpgroup == 0)
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx);
// Write shared memory
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
@@ -449,13 +449,13 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
if (s >= kNumTMAStoreStages) {
if (epilogue_thread_idx_in_warpgroup == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx);
}
// The pipeline stage
const auto tma_stage_idx = s % kNumTMAStoreStages;
const auto m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_m, BLOCK_M, m_block_idx);
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(n_block_idx * BLOCK_N + s * STORE_BLOCK_N);
const auto local_smem_cd = smem_cd[tma_stage_idx] + epilogue_warpgroup_idx * STORE_BLOCK_M * STORE_BLOCK_N;
// Store into shared memory
@@ -502,7 +502,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx);
if (epilogue_thread_idx_in_warpgroup == 0) {
cute::SM90_TMA_STORE_2D::copy(
&tensor_map_d, local_smem_cd,
@@ -512,10 +512,6 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
}
}
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
if (epilogue_thread_idx_in_warpgroup == 0)
cute::tma_store_wait<0>();
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
if (epilogue_warp_idx == 1)

View File

@@ -25,7 +25,8 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs, GemmType kGemmType>
uint32_t kNumSMs, GemmType kGemmType,
typename cd_dtype_t>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_bf16_gemm_impl(int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
@@ -44,7 +45,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Shared memory
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(cd_dtype_t);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
@@ -55,7 +56,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
const uint32_t lane_idx = get_lane_idx();
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == kNumMathThreads) {
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
@@ -67,7 +68,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Data on shared memory
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
auto smem_d = reinterpret_cast<cd_dtype_t*>(smem_buffer);
__nv_bfloat16* smem_a[kNumStages];
__nv_bfloat16* smem_b[kNumStages];
@@ -91,7 +92,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
}
// Initialize barriers
if (threadIdx.x == kNumMathThreads) {
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
@@ -99,7 +100,6 @@ sm90_bf16_gemm_impl(int* grouped_layout,
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
cutlass::arch::fence_barrier_init();
}
@@ -125,14 +125,14 @@ sm90_bf16_gemm_impl(int* grouped_layout,
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
if (threadIdx.x >= kNumMathThreads) {
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) {
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto divisible_type) {
@@ -203,7 +203,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
}
};
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Launch MMAs
launch_k_iterations([&](uint32_t k_iter, auto divisible_type) {
@@ -237,11 +237,10 @@ sm90_bf16_gemm_impl(int* grouped_layout,
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive(s);
}
// Notify barrier arrival
empty_barrier_arrive(s);
}
// Wait unaligned cases
@@ -256,7 +255,6 @@ sm90_bf16_gemm_impl(int* grouped_layout,
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type");
DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom");
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
"Unaligned TMA store or too many TMA store instructions");
@@ -265,60 +263,76 @@ sm90_bf16_gemm_impl(int* grouped_layout,
// Wait last TMA store to be finished
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Write back to shared memory using STSM and issue TMA stores
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
if constexpr (std::is_same_v<cd_dtype_t, cutlass::bfloat16_t>) {
// Write back to shared memory using STSM and issue TMA stores
DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type");
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// Swizzle or padding into the correct address
uint8_t* smem_ptr = nullptr;
if constexpr (kSwizzleDMode > 0) {
// Calculate the swizzling atom offset and in-atom offset
constexpr uint32_t kNumBankGroupBytes = 16;
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// Swizzle or padding into the correct address
uint8_t* smem_ptr = nullptr;
if constexpr (kSwizzleDMode > 0) {
// Calculate the swizzling atom offset and in-atom offset
constexpr uint32_t kNumBankGroupBytes = 16;
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
// - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
col ^= row % (kSwizzleDMode / 16);
// Reshape the atom in another view and swizzle
// - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
// - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
col ^= row % (kSwizzleDMode / 16);
// Add back into the base pointer
// NOTES: think twice before modifying this, as changes may affect the number of instructions
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
m_offset * kSwizzleDMode + // Wave offset
atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
} else {
// No swizzling, just padding
// TODO: support more cases
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
// Add back into the base pointer
// NOTES: think twice before modifying this, as changes may affect the number of instructions
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
m_offset * kSwizzleDMode + // Wave offset
atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
} else {
// No swizzling
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
}
// NOTES: only 16 lanes' addresses are used
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
smem_ptr
);
}
}
}
else {
// Use `st.shared` if STSM is not available
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2);
auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2);
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
}
// NOTES: only 16 lanes' addresses are used
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
smem_ptr
);
}
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Use TMA store to write back to global memory
// TODO: compatible with FP32 output
constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked;
DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {

View File

@@ -0,0 +1,173 @@
#pragma once
#include <cute/arch/cluster_sm90.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kSplitFactor,
uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
float *d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Types
using WGMMA = typename BF16MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
// Shared memory
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
// Configs
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_idx();
DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M");
DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads");
DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads");
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
// Fill shared memory pointers
extern __shared__ __align__(1024) uint8_t smem_buffer[];
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumMathThreads);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
}
// Synchronize all threads to make barrier visible in normal memory model
__syncthreads();
// Register reconfigurations
constexpr uint32_t kNumTMARegisters = 40;
constexpr uint32_t kNumMathRegisters = 232;
// Block indices
const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
// Persistently schedule over blocks
#pragma unroll
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait consumer release
const auto& stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1);
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
const uint32_t& k_idx = sk_idx % SHAPE_K;
const uint32_t& s_idx = sk_idx / SHAPE_K;
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1);
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
float accum[WGMMA::kNumAccum] = {0};
// Launch MMAs
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrivals
const auto& stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, 1);
}
warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
empty_barriers[stage_idx]->arrive();
}
const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
if (col + i * 8 >= SHAPE_N)
break;
if (row < SHAPE_M) {
atomicAdd(reinterpret_cast<float2*>(d + (row + 0) * SHAPE_N + col + i * 8),
make_float2(accum[i * 4 + 0], accum[i * 4 + 1]));
}
if (row + 8 < SHAPE_M) {
atomicAdd(reinterpret_cast<float2*>(d + (row + 8) * SHAPE_N + col + i * 8),
make_float2(accum[i * 4 + 2], accum[i * 4 + 3]));
}
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
}; // namespace deep_gemm

View File

@@ -1,3 +1,348 @@
#pragma once
// TODO: add implement
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumGroups,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs,
GemmType kGemmType, typename cd_dtype_t>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
int* grouped_layout,
cute::TmaDescriptor* tensor_map_buffer,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a_base,
const __grid_constant__ cute::TmaDescriptor tensor_map_b_base,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads");
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
// Overwrite shape constants if the compiler gives
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Shared memory
static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment");
// Configs
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = threadIdx.x % 32;
// Prefetch TMA descriptors at the very beginning
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a_base);
cute::prefetch_tma_descriptor(&tensor_map_b_base);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
cute::prefetch_tma_descriptor(&tensor_map_sfb);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Tensor maps on shared and global memory
auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * i);
});
auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * (2 + i));
});
auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; });
auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; });
// Data on shared memory
auto smem_d = reinterpret_cast<float*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE));
});
auto smem_sfb = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE));
});
// Barriers on shared memory
constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE);
auto full_barriers = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast<uint32_t>(sizeof(Barrier))));
});
auto empty_barriers = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast<uint32_t>(sizeof(Barrier))));
});
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
// Load tensormap A/B to shared memory
if constexpr (kGemmType == GemmType::KGroupedContiguous) {
*smem_tensor_map_a[0] = tensor_map_a_base;
*smem_tensor_map_a[1] = tensor_map_a_base;
*smem_tensor_map_b[0] = tensor_map_b_base;
*smem_tensor_map_b[1] = tensor_map_b_base;
}
// Initialize barriers
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
// even with TMA multicast disabled, we want to make the behavior aligned
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
}
// Synchronize all threads to make barrier visible in normal memory model
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// Pipeline unroll control
constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages);
// Register reconfigurations (more math registers are needed with unrolling)
constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24);
constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240);
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs, 128u>(shape_m, shape_n, shape_k, grouped_layout);
// TMA and MMA pipeline
const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple<uint32_t, uint32_t> {
return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase
};
uint32_t iter_idx = 0;
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base;
const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base;
uint32_t last_group_idx = kNumGroups, sum_k = 0;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Assign TMA multicast number into A and B
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
const uint32_t& m_idx = m_block_idx * BLOCK_M;
const uint32_t& n_idx = n_block_idx * BLOCK_N;
if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) {
const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1;
const uint32_t& next_stage_idx = stage_idx ^ 1;
last_group_idx = scheduler.current_group_idx;
// Prepare next tensor map
sum_k += scheduler.current_shape_k;
if (scheduler.next_group_idx < kNumGroups) {
tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + sum_k * shape_m);
tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + sum_k * shape_n);
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
*(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]);
*(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]);
tensor_map_release_cta();
}
// Get current tensor map
if (scheduler.current_num_valid_groups > 0) {
tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]);
tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]);
current_tensor_map_a = gmem_tensor_map_a[stage_idx];
current_tensor_map_b = gmem_tensor_map_b[stage_idx];
}
}
#pragma unroll kNumPipelineUnrolls
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) {
// Wait consumer release
CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
empty_barriers[stage_idx]->wait(phase ^ 1);
// Issue TMA
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t& k_idx = k_block_idx * BLOCK_K;
const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
tma_copy(&tensor_map_sfa, reinterpret_cast<uint64_t*>(&full_barrier), smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
tma_copy(&tensor_map_sfb, reinterpret_cast<uint64_t*>(&full_barrier), smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
tma_copy(current_tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier), smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
tma_copy(current_tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier), smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE);
}
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s) {
CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
empty_barriers[stage_idx]->wait(phase ^ 1);
}
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4;
const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Accumulation for WGMMA or CUDA promotion
DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes");
const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K);
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
float2 scales_b[WGMMA::kNumAccum / 4];
// Empty barrier arrival
auto empty_barrier_arrive = [&](uint32_t s) {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
}
};
#pragma unroll kNumPipelineUnrolls
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) {
// Wait TMA arrivals
CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
full_barriers[stage_idx]->wait(phase);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0);
auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1);
// Read B scales
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(stage_idx);
// Promote with scales
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
const float &scale_b_0 = scales_b[i].x;
const float &scale_b_1 = scales_b[i].y;
final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
}
}
// Flush previous stores
if (warp_idx % 4 == 0 and cute::elect_one_sync())
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
// Store to D shared memory
const auto& smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
const auto& smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
// Use TMA store to write back to global memory
if (warp_idx % 4 == 0 and cute::elect_one_sync()) {
cute::SM90_TMA_REDUCE_ADD_2D::copy(
&tensor_map_d, smem_d_0, n_block_idx * BLOCK_N,
current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0);
cute::tma_store_arrive();
}
__syncwarp();
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -10,6 +10,7 @@
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include <deep_gemm/common/epilogue_utils.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
@@ -18,15 +19,15 @@ namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd>
__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) {
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd, typename func_t>
__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
if (num_former_iters == kNumFormerIters) {
inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
func(cute::Int<kNumFormerIters>{});
return;
}
if constexpr (kNumFormerIters + kGap <= kEnd)
outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
dispatch_num_former_iters<kNumFormerIters + kGap, kGap, kEnd>(num_former_iters, func);
}
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
@@ -36,7 +37,8 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs, GemmType kGemmType>
uint32_t kNumSMs, GemmType kGemmType,
typename epilogue_type_t>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
@@ -69,14 +71,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
const uint32_t& smem_sfb_size = align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
// Configs
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K);
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_idx();
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == kNumMathThreads) {
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
@@ -90,35 +90,26 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Data on shared memory
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
__nv_fp8_e4m3* smem_a[kNumStages];
__nv_fp8_e4m3* smem_b[kNumStages];
float* smem_sfa[kNumStages];
float* smem_sfb;
// TMA Barrier for both divisible and non-divisible cases
Barrier* full_barriers[kNumStages];
Barrier* empty_barriers[kNumStages];
// Fill shared memory pointers
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
smem_sfa[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SFA_SIZE_PER_STAGE);
}
smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE));
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE);
});
auto smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE);
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i] = barrier_start_ptr + i;
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
}
auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
// Initialize barriers
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
if (threadIdx.x == kNumMathThreads) {
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
// even with TMA multicast disabled, we want to make the behavior aligned
#pragma unroll
@@ -128,107 +119,72 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
cutlass::arch::fence_barrier_init();
}
// Synchronize all threads to make barrier visible in normal memory model
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
struct SkipComputation {};
struct NotSkipComputation {};
auto launch_k_iterations = [=](const auto& func, bool skip_computation, uint32_t num_former_iters) {
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
// NOTES: for too-many branches (> 5), we disable this optimization
// Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) {
if (skip_computation) {
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type);
} else if (shape_k % kFullKOfAllStages == 0) {
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
} else {
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
func(num_iterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type);
}
}, func, kShouldOptimize ? num_former_iters : 0);
};
// Register reconfigurations
constexpr uint32_t kNumTMARegisters = 40;
constexpr uint32_t kNumMathRegisters = 232;
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
if (threadIdx.x >= kNumMathThreads) {
// Pipeline and TMA phases
uint32_t stage_idx = 0, phase = 0;
auto advance_pipeline = [&](uint32_t& k_block_idx) {
++ k_block_idx;
// Flip phases only if reach the next first stage
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
phase ^= stage_idx == 0;
};
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) {
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) {
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(divisible_type), DivisibleK>;
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
// Assign TMA multicast number into A and B
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
// Assign TMA multicast number into A and B
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
empty_barriers[stage_idx]->wait(phase ^ 1);
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
// Issue TMA A
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t k_idx = k_block_idx * BLOCK_K;
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
num_tma_multicast_a);
tma_copy(&tensor_map_sfa, reinterpret_cast<uint64_t*>(&full_barrier),
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.get_global_idx<kWithGroupOffsetA>(shape_k_scales, 1, k_block_idx),
num_tma_multicast_a);
// Issue TMA A
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
auto& full_barrier = *full_barriers[s];
uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
num_tma_multicast_a);
tma_copy(&tensor_map_sfa, reinterpret_cast<uint64_t*>(&full_barrier),
smem_sfa[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx<kWithGroupOffsetA>(shape_k_scales, 1, k_idx / BLOCK_K),
num_tma_multicast_a);
// Issue TMA B
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
num_tma_multicast_b);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
}, false, 0);
// Issue TMA B
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
num_tma_multicast_b);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
}
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s)
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
for (uint32_t i = 0; i < kNumStages; advance_pipeline(i))
empty_barriers[stage_idx]->wait(phase ^ 1);
}
}
} else {
@@ -239,6 +195,11 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
auto b_desc = make_smem_desc(smem_b[0], 1);
const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Decide the number of scales B to load
@@ -259,7 +220,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
st_shared(smem_sfb + i, __ldg(local_sfb + i));
}
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Accumulation for WGMMA or CUDA promotion
constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2);
@@ -267,90 +228,96 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
// Empty barrier arrival
auto empty_barrier_arrive = [&](uint32_t s) {
auto empty_barrier_arrive = [&]() {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void();
} else {
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void();
}
};
// Launch MMAs
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) {
constexpr bool kSkipComputation = cute::is_same_v<decltype(skip_type), SkipComputation>;
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(divisible_type), DivisibleK>;
constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? kNumStages : kNumLastStages);
// Skip useless computations
if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) {
// The compiler must know the dynamic variable `num_former_iters`'s real value
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_sfb + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_sfb + k_iter * kNumStages + s + shape_k_scales);
// Dispatch `num_former_iters` and launch MMAs
dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) {
#pragma unroll 8
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
// Read B scales
float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales);
// TODO: remove some useless computation for unaligned Ms
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
// Wait TMA arrivals
full_barriers[stage_idx]->wait(phase);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_sfa[s] + r_0 + m_offset);
auto scale_a_1 = ld_shared(smem_sfa[s] + r_1 + m_offset);
// TODO: remove some useless computation for unaligned Ms
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0 + m_offset);
auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1 + m_offset);
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive(s);
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16;
b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16;
WGMMA::wgmma(a_desc, b_desc, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive();
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
}
}
}
// Wait unaligned cases
});
} else {
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
empty_barrier_arrive(s);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
full_barriers[stage_idx]->wait(phase);
empty_barrier_arrive();
}
}, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters);
}
// TMA checks
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
@@ -364,7 +331,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Wait last TMA store to be finished
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Write back to shared memory using STSM and issue TMA stores
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
@@ -413,7 +380,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
}
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Use TMA store to write back to global memory
// TODO: compatible with FP32 output
@@ -423,7 +390,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
n_block_idx * BLOCK_N + in_block_n_offset,
epilogue_type_t::apply_index_n<TMA_D_BLOCK_N>(n_block_idx * BLOCK_N + in_block_n_offset),
scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
}

View File

@@ -91,7 +91,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
with suppress():
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) if not using_nsys else None
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
with profiler:
for i in range(2):
@@ -112,10 +112,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
is_tuple = isinstance(kernel_names, tuple)
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
if not with_multiple_kernels:
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'
# Save chrome traces
if trace_path is not None:
@@ -136,6 +135,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
total_num += int(num_str)
break
kernel_times.append(total_time / total_num)
kernel_times.append(total_time / total_num if total_num > 0 else 0)
return tuple(kernel_times) if is_tuple else kernel_times[0]

View File

@@ -16,13 +16,16 @@ def ceil_to_ue8m0(x: torch.Tensor):
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
assert x.dim() == 2
m, n = x.shape
x_view = x.view(m, -1, 128)
padded_n = align(n, 128)
x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0)
x_padded[:, :n] = x
x_view = x_padded.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
sf = x_amax / 448.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -54,4 +57,4 @@ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -
sf = x_amax / 448.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled, sf.squeeze()
return x_scaled, sf.squeeze()

View File

@@ -1,7 +1,7 @@
import enum
import random
import torch
from typing import Generator, Tuple, List
from typing import Generator, List
from deep_gemm.utils import (
align, ceil_div,
@@ -11,7 +11,6 @@ from deep_gemm.utils import (
class KernelType(enum.Enum):
# For SM100 GEMMs
Kernel1D1D = 0
Kernel1D2D = 1
KernelNoSF = 2
@@ -48,62 +47,87 @@ def get_ue8m0_usage(kernel_type: KernelType) -> bool:
return kernel_type.is_1d1d()
def get_kernel_types(use_bf16: bool = False) -> tuple:
if use_bf16:
def get_kernel_types(dtype: torch.dtype) -> tuple:
if dtype == torch.bfloat16:
return (KernelType.KernelNoSF, )
return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D)
# TODO: SM100 1D2D kernels are going to be deprecated
# But if you want to test it, please use:
# `(KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D)`
return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, )
def get_out_dtype() -> tuple:
return (torch.bfloat16, ) if get_arch_major() == 9 else (torch.bfloat16, torch.float)
def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator:
for major_a in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor):
for major_b in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor):
if major_a.is_mn_major() and not allow_a_mn_major:
continue
if major_b.is_mn_major() and not allow_b_mn_major:
continue
yield major_a, major_b
def get_major_ab(freeze_a: bool) -> tuple:
# TODO: test other major-ness for SM90 BF16 GEMMs
if get_arch_major() == 9:
return ((MajorTypeAB.KMajor, MajorTypeAB.KMajor), )
if freeze_a:
return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor)
return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor), \
(MajorTypeAB.MNMajor, MajorTypeAB.KMajor), (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
def enumerate_normal(dtype: torch.dtype) -> Generator:
assert dtype in (torch.float8_e4m3fn, torch.bfloat16)
fp32_output_nk = [(256, 7168), (129280, 7168)]
bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]
m_fwd_list, m_bwd_list = [128, 4096], [4096, ]
nk_list = bf16_output_nk
# Only BF16 GEMM needs FP32 outputs
if dtype == torch.bfloat16:
nk_list += fp32_output_nk
for kernel_type in get_kernel_types(dtype):
# Forward
for m in m_fwd_list:
for n, k in nk_list:
out_dtype = torch.float if (n, k) in fp32_output_nk else torch.bfloat16
yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype
# TODO: support BF16 SM90 MN-major kernels
if dtype == torch.bfloat16 and get_arch_major() == 9:
continue
# Backward
for m in m_bwd_list:
for n, k in nk_list:
override_major = MajorTypeAB.MNMajor
override_kernel_type = kernel_type
if get_arch_major() == 9 and dtype == torch.float8_e4m3fn:
override_major = MajorTypeAB.KMajor
override_kernel_type = KernelType.Kernel1D1D
yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad
yield override_kernel_type, n, m, k, override_major, override_major, True, torch.float # Wgrad
yield override_kernel_type, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad
def enumerate_normal(use_bf16: bool = False) -> Generator:
for kernel_type in get_kernel_types(use_bf16):
for m in (128, 4096):
for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]:
for major_a, major_b in get_major_ab(False):
for out_dtype in get_out_dtype():
for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True):
yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype
def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator:
for kernel_type in get_kernel_types(use_bf16):
def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator:
for kernel_type in get_kernel_types(dtype):
for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)):
for major_a, major_b in get_major_ab(True):
for major_a, major_b in get_major_ab(False, get_arch_major() > 9):
yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b
def enumerate_m_grouped_masked() -> Generator:
def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator:
max_m = 4096
for kernel_type in get_kernel_types():
for kernel_type in get_kernel_types(dtype):
for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
for n, k in ((4096, 7168), (7168, 2048), ):
yield kernel_type, num_groups, max_m, m, n, k
def enumerate_k_grouped_contiguous():
# TODO: support SM90 kernels
if get_arch_major() == 9:
return []
# Only K-major is supported for SM90
major_a, major_b = (MajorTypeAB.KMajor, MajorTypeAB.KMajor) if get_arch_major() == 9 \
else (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
# Must with FP32 accumulation and 1D1D kernels
for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64
( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32
(16, 4096, 7168, 2048), (16, 7168, 2048, 2048)): # EP16
ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)]
yield num_groups, m, n, ks, expected_k_per_group
yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group
def enumerate_sf_layout():
@@ -134,6 +158,7 @@ def enumerate_transpose():
def generate_normal(m: int, n: int, k: int,
major_a: MajorTypeAB, major_b: MajorTypeAB,
accumulate: bool, out_dtype: torch.dtype,
kernel_type: KernelType,
use_ue8m0: bool = False, use_bf16: bool = False):
a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
@@ -147,7 +172,9 @@ def generate_normal(m: int, n: int, k: int,
b = b if major_b.is_k_major() else b.T.contiguous().T
return a, b, c, d, ref_d
a_fp8, b_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0), per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0)
a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0)
b_fp8 = per_token_cast_to_fp8(b, use_ue8m0=use_ue8m0) if kernel_type.is_1d1d() and accumulate \
else per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0)
a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1])
b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1])
return a_fp8, b_fp8, c, d, ref_d
@@ -214,7 +241,7 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group:
return a_fp8, b_fp8, masked_m, d, ref_d
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int], use_ue8m0: bool):
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], use_ue8m0: bool):
assert get_mk_alignment_for_contiguous_layout() % 128 == 0
k = sum(ks)
@@ -232,4 +259,20 @@ def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int]
a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0)
b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0)
# Transpose for K Major A/B
if (major_a, major_b) == (MajorTypeAB.KMajor, MajorTypeAB.KMajor):
a, sfa = a_fp8
b, sfb = b_fp8
new_a = torch.empty((sum(ks) * m, ), dtype=a.dtype, device=a.device)
new_b = torch.empty((sum(ks) * n, ), dtype=b.dtype, device=b.device)
prefix = 0
for K in ks:
new_a[prefix * m : (prefix + K) * m] = a[prefix : prefix + K, ].T.flatten()
new_b[prefix * n : (prefix + K) * n] = b[prefix : prefix + K, ].T.flatten()
prefix += K
a_fp8, b_fp8 = (new_a, sfa.T), (new_b, sfb.T)
else:
assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
return k, a_fp8, b_fp8, c, d, ref_d

64
tests/test_attention.py Normal file
View File

@@ -0,0 +1,64 @@
import random
import torch
from typing import Tuple
import deep_gemm
from deep_gemm.testing import bench_kineto, calc_diff, count_bytes
from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8
from generators import get_arch_major, generate_normal, get_ue8m0_usage, get_kernel_types, MajorTypeAB
def apply_skip_head_mid(d: torch.Tensor, head_splits: Tuple[int, int, int]):
left, mid, right = head_splits
m, n = d.shape
assert n % (left + right) == 0
num_heads = n // (left + right)
# Split and insert padding tensor
d = d.view(m, num_heads, -1)
d_left = d[:, :, :left]
d_right = d[:, :, -right:]
d_mid = torch.zeros((m, num_heads, mid), dtype=d.dtype, device=d.device)
return torch.cat([d_left, d_mid, d_right], dim=2).view(m, -1)
def test_gemm_skip_head_mid() -> None:
print('Testing GEMM skip head mid:')
head_splits = (128, 64, 128)
major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor
out_dtype, accumulate = torch.bfloat16, False
for kernel_type in get_kernel_types(dtype=torch.float8_e4m3fn):
for m in (128, 4096):
for n, k in [(32768, 512), (8192, 512)]:
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
use_ue8m0 = get_ue8m0_usage(kernel_type)
disable_ue8m0_cast = not use_ue8m0
a, b, _, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
d = apply_skip_head_mid(d, head_splits)
ref_d = apply_skip_head_mid(ref_d, head_splits)
deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast)
diff = calc_diff(d, ref_d)
assert diff < 0.001, f'{m=}, {n=}, {k=}, {kernel_opt}, {diff:.5f}'
t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast),
'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}): '
f'{t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, b, d)) / 1e9 / t:4.0f} GB/s')
print()
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.manual_seed(0)
random.seed(0)
test_gemm_skip_head_mid()

View File

@@ -7,6 +7,7 @@ from deep_gemm.testing import (
calc_diff, count_bytes
)
from generators import (
get_arch_major,
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, generate_normal,
generate_m_grouped_contiguous, generate_m_grouped_masked
)
@@ -14,14 +15,18 @@ from generators import (
def test_gemm() -> None:
print('Testing GEMM:')
for _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(use_bf16=True):
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16):
# TODO: support accumulation for SM90 BF16 GEMM
if get_arch_major() == 9 and accumulate:
continue
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
acc_opt = f'acc={int(accumulate)}'
for test_alias in (False, True):
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True)
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True)
func_name = f'bf16_gemm_{major_opt.lower() if test_alias else "nt"}'
if test_alias:
a = a if major_a.is_k_major() else a.T
@@ -31,28 +36,22 @@ def test_gemm() -> None:
diff = calc_diff(d, ref_d)
assert diff < 0.0001, (f'{m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=}, '
f'{diff:.5f}, alias={test_alias}')
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True)
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True)
cublas_t = 0
t = bench_kineto(lambda: deep_gemm.bf16_gemm_nt(a, b, d, c=c), 'bf16_gemm', suppress_kineto_output=True)
if accumulate == 0 and out_dtype == torch.bfloat16:
# noinspection PyBroadException
try:
cublas_t = bench_kineto(lambda: a @ b.T, 'nvjet', suppress_kineto_output=True)
except Exception:
pass
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, {out_opt}, {acc_opt}): '
f'{t * 1e6:4.0f} us | '
cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True)
print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): '
f'{t * 1e6:5.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | '
f'{cublas_t / t:.2f}x cuBLAS')
f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS')
print()
def test_m_grouped_gemm_contiguous() -> None:
print('Testing m-grouped contiguous GEMM:')
for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(use_bf16=True):
for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(torch.bfloat16):
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
@@ -85,7 +84,7 @@ def test_m_grouped_gemm_masked() -> None:
print('Testing m-grouped masked GEMM:')
# TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked():
for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.bfloat16):
# Test correctness
for i in range(10):
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True)
@@ -111,6 +110,27 @@ def test_m_grouped_gemm_masked() -> None:
print()
def test_cublaslt_gemm() -> None:
print('Testing cuBLASLt GEMM:')
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16):
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
acc_opt = f'acc={int(accumulate)}'
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True)
deep_gemm.cublaslt_gemm_nt(a, b, d, c=c)
diff = calc_diff(d, ref_d)
assert diff < 5e-7, f'{diff=}, ({m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=})'
t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), 'nvjet', suppress_kineto_output=True,)
print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): '
f'{t * 1e6:5.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s')
print()
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
@@ -121,5 +141,9 @@ if __name__ == '__main__':
print(f' > {deep_gemm.__path__}\n')
test_gemm()
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()
# TODO: support SM100
if get_arch_major() == 9:
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()
test_cublaslt_gemm()

85
tests/test_einsum.py Normal file
View File

@@ -0,0 +1,85 @@
import random
import torch
import deep_gemm
from deep_gemm.testing import (
bench, bench_kineto,
calc_diff, count_bytes
)
def test_bmk_bnk_mn() -> None:
print('Testing "bmk, bnk -> mn":')
for s in (129, 4096, 8192):
for m, n, k in [(128, 384, 128), (256, 256, 256), (384, 128, 384)]:
for dtype in (torch.float, torch.bfloat16):
a = torch.randn((s, m, k), dtype=torch.bfloat16, device='cuda')
b = torch.randn((s, n, k), dtype=torch.bfloat16, device='cuda')
d = torch.randn((m, n), dtype=dtype, device='cuda')
c = d if dtype == torch.float else None
# Test correctness
ref_d = (c if dtype == torch.float else 0) + torch.bmm(a.float(), b.float().mT).sum(0)
deep_gemm.einsum('bmk,bnk->mn', a, b, d, c=c)
assert calc_diff(d, ref_d) < 1e-5
t = bench_kineto(lambda: deep_gemm.einsum('bmk,bnk->mn', a, b, d, c=c), 'bmn_bnk_mn_gemm_impl', suppress_kineto_output=True)
print(f' > Perf (b={s:4.0f}, {m=}, {n=}, {k=}, {"FP32" if dtype == torch.float else "BF16"}): ',
f'{t * 1e6:4.0f} us | '
f'{2 * s * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, b) + (d.numel() * 4)) / 1e9 / t:4.0f} GB/s')
print()
def test_bhr_hdr_bhd():
print('Testing "bhr, hdr -> bhd":')
for b in (128, 4096, 8192):
for h, r, d in [(128, 512, 128)]:
x = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16)
fy = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16)
y = fy[:, :, :r]
ref_z = torch.einsum('bhr,hdr->bhd', x, y)
z = torch.empty((b, h, d), device='cuda', dtype=torch.bfloat16)
deep_gemm.einsum('bhr,hdr->bhd', x, y, z)
assert calc_diff(z, ref_z) < 1e-10
t = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z), 'nvjet', suppress_kineto_output=True)
print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ',
f'{t * 1e6:4.0f} us | '
f'{2 * b * h * r * d / t / 1e12:.0f} TFLOPS | '
f'{count_bytes((x, y, z)) / t / 1e9:.0f} GB/s')
print()
def test_bhd_hdr_bhr():
print('Testing "bhd, hdr -> bhr":')
for b in (128, 4096, 8192):
for h, r, d in [(128, 512, 128)]:
x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16)
fy = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16)
y = fy[:, :, :r]
ref_z = torch.einsum('bhd,hdr->bhr', x, y)
z = torch.empty((b, h, r), device='cuda', dtype=torch.bfloat16)
deep_gemm.einsum('bhd,hdr->bhr', x, y, z)
assert calc_diff(z, ref_z) < 1e-10
t = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z), 'nvjet', suppress_kineto_output=True)
print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ',
f'{t * 1e6:4.0f} us | '
f'{2 * b * h * r * d / t / 1e12:.0f} TFLOPS | '
f'{count_bytes((x, y, z)) / t / 1e9:.0f} GB/s')
print()
if __name__ == '__main__':
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.manual_seed(0)
random.seed(0)
print('Library path:')
print(f' > {deep_gemm.__path__}\n')
test_bmk_bnk_mn()
test_bhr_hdr_bhd()
test_bhd_hdr_bhr()

View File

@@ -10,7 +10,7 @@ from deep_gemm.testing import (
)
from generators import (
KernelType, get_ue8m0_usage,
KernelType, get_arch_major, get_ue8m0_usage,
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
)
@@ -18,7 +18,7 @@ from generators import (
def test_gemm() -> None:
print('Testing GEMM:')
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal():
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn):
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
@@ -26,42 +26,35 @@ def test_gemm() -> None:
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
use_ue8m0 = get_ue8m0_usage(kernel_type)
disable_ue8m0_cast = not use_ue8m0
recipe = (1, 1, 128) if kernel_type.is_1d1d() and accumulate else None
for test_alias in (False, True):
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0)
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
func_name = f'fp8_gemm_{major_opt.lower() if test_alias else "nt"}'
if test_alias:
a = a if major_a.is_k_major() else (a[0].T, a[1].T)
b = b if major_b.is_k_major() else (b[0].T, b[1].T)
assert a[0].is_contiguous() and b[0].is_contiguous()
getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe)
diff = calc_diff(d, ref_d)
assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, '
f'{diff:.5f}, alias={test_alias}')
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0)
# Test launch overhead
launch_start_t = time.time_ns()
deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
launch_end_t = time.time_ns()
torch.cuda.synchronize()
# noinspection PyShadowingNames
def test_func():
deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): '
f'launch {(launch_end_t - launch_start_t) / 1e3:4.0f} us | {t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s')
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe),
'fp8_gemm', suppress_kineto_output=True)
cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True)
print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): '
f'{t * 1e6:4.0f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | '
f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS')
print()
def test_m_grouped_gemm_contiguous() -> None:
print('Testing m-grouped contiguous GEMM:')
for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous():
for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn):
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
@@ -86,7 +79,7 @@ def test_m_grouped_gemm_contiguous() -> None:
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}): '
print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}): '
f'{t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s')
@@ -97,7 +90,7 @@ def test_m_grouped_gemm_masked() -> None:
print('Testing m-grouped masked GEMM:')
# TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked():
for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.float8_e4m3fn):
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
use_ue8m0 = get_ue8m0_usage(kernel_type)
disable_ue8m0_cast = not use_ue8m0
@@ -130,26 +123,31 @@ def test_m_grouped_gemm_masked() -> None:
def test_k_grouped_gemm_contiguous() -> None:
print('Testing k-grouped contiguous GEMM:')
for num_groups, m, n, ks, expected_k_per_group in enumerate_k_grouped_contiguous():
k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \
else deep_gemm.k_grouped_fp8_gemm_tn_contiguous
for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous():
use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D)
for test_empty_groups in (False, True):
new_ks = copy.deepcopy(ks)
if test_empty_groups:
if test_empty_groups and len(ks) > 1:
new_ks[random.randint(0, num_groups - 1)] = 0
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, new_ks, use_ue8m0=use_ue8m0)
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0)
new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda')
deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, new_ks, new_ks_tensor, c=c)
diff = calc_diff(d, ref_d)
assert diff < 0.001, f'{m=}, {n=}, {k=}, {i=}, {diff:.5f}'
k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c)
do_check = True
if do_check:
diff = calc_diff(d, ref_d)
assert diff < 0.001, f'{m=}, {n=}, {k=}, {ks=}, {diff:.5f}'
# Test performance
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, ks, use_ue8m0=use_ue8m0)
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0)
ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
# noinspection PyShadowingNames
def test_func():
deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=c)
k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): '