Make various updates and fixes (#198)
This commit is contained in:
77
csrc/apis/attention.hpp
Normal file
77
csrc/apis/attention.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
#pragma once
|
||||
|
||||
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
|
||||
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
|
||||
|
||||
#include "layout.hpp"
|
||||
|
||||
namespace deep_gemm::attention {
|
||||
|
||||
// FP8 GEMM of shape `[M, K] @ [N, K].T`, launched with an
// `EpilogueHeadSplits<left, mid, right>` epilogue.
//
// Arguments:
//   - `a`, `b`: pairs of (FP8 E4M3 data, scaling factors)
//   - `d`: BF16 or FP32 output, asserted to be `[M, N + N / (left + right) * mid]`
//   - `head_splits`: `(left, mid, right)`, all non-negative with `left + right > 0`;
//     `N` must be a multiple of `left + right`
//   - `recipe`: scaling-factor recipe; derived from the SF dtypes when not given,
//     only `(1, 1, 128)` and `(1, 128, 128)` are accepted
//   - `compiled_dims`: forwarded to the kernel launcher
//   - `disable_ue8m0_cast`: forwarded to the scaling-factor layout transform
//
// Fails a host assertion on any shape/dtype/layout mismatch, and hits
// `DG_HOST_UNREACHABLE` for unsupported architecture / SF dtype combinations.
static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const std::tuple<int, int, int> &head_splits,
                                      std::optional<std::tuple<int, int, int>> recipe,
                                      const std::string& compiled_dims,
                                      const bool& disable_ue8m0_cast) {
    // Shape must be `[M, K] @ [N, K].T`
    const auto& major_a = get_major_type_ab(a.first);
    const auto& major_b = get_major_type_ab(b.first);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }

    // D must be N-major
    check_major_type_cd(d);

    // Type and shape checks
    const auto& [m , k ] = get_shape<2>(a.first);
    const auto& [n , k_] = get_shape<2>(b.first);
    const auto& [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Check head splits and N
    // NOTES: the splits are validated first, otherwise `n % (left + right)` below
    // would be a division by zero for `left + right == 0`
    const auto& [left, mid, right] = head_splits;
    DG_HOST_ASSERT(left >= 0 and mid >= 0 and right >= 0 and left + right > 0);
    DG_HOST_ASSERT(n % (left + right) == 0 and n_ == n + n / (left + right) * mid);

    // Do nothing if the problem is empty
    if (m == 0)
        return;

    // Transform SFA and SFB into compute-required layout
    if (not recipe.has_value())
        recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
    DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128));
    const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
    const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);

    // Dispatch into different implementations
    // NOTES: on SM90 only the 1D2D kernel is wired up here, so a recipe whose
    // second entry is 1 falls through to the unsupported branch
    const auto& arch_major = device_runtime->get_arch_major();
    const auto& epilogue_type = fmt::format("EpilogueHeadSplits<{}, {}, {}>", left, mid, right);
    if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) {
        sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
    } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
        sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
    } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
        sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
    }
}
|
||||
|
||||
// Registers the attention APIs on the given pybind11 module.
// `fp8_gemm_nt_skip_head_mid` takes `a`, `b`, `d` and `head_splits` as required
// arguments; `recipe` defaults to `std::nullopt`, `compiled_dims` to "nk" and
// `disable_ue8m0_cast` to `false`.
static void register_apis(pybind11::module_& m) {
|
||||
m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("head_splits"),
|
||||
py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm::attention
|
||||
115
csrc/apis/einsum.hpp
Normal file
115
csrc/apis/einsum.hpp
Normal file
@@ -0,0 +1,115 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../utils/exception.hpp"
|
||||
#include "../utils/format.hpp"
|
||||
#include "../utils/layout.hpp"
|
||||
|
||||
#include "../jit_kernels/impls/sm90_bmk_bnk_mn.hpp"
|
||||
#include "../jit_kernels/impls/sm100_bmk_bnk_mn.hpp"
|
||||
#include "../jit_kernels/impls/smxx_cublaslt.hpp"
|
||||
|
||||
namespace deep_gemm::einsum {
|
||||
|
||||
// Einsum `bmk,bnk->mn`: `a` is `[s, m, k]`, `b` is `[s, n, k]`, both contiguous.
//
// Output handling:
//   - FP32 `d`: only the accumulating form is supported, so `c` must be given
//     and must alias `d` (same data pointer, sizes and strides)
//   - BF16 `d`: `c` must be absent; the result is accumulated into a zeroed
//     FP32 workspace through a recursive call and then copied into `d`
//
// Fails a host assertion on any violated precondition, and hits
// `DG_HOST_UNREACHABLE` on architectures other than SM90/SM100.
static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d,
                       const std::optional<torch::Tensor>& c) {
    // Currently FP32 only supports the accumulated expression
    if (d.scalar_type() == torch::kFloat) {
        // NOTES: `c` must be checked before it is dereferenced, `c->` on an
        // empty optional is undefined behavior
        DG_HOST_ASSERT(c.has_value());
        DG_HOST_ASSERT(c->data_ptr() == d.data_ptr() and c->sizes() == d.sizes() and c->strides() == d.strides());
    } else {
        DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
        DG_HOST_ASSERT(not c.has_value());

        // Accumulate in FP32, the workspace is zeroed on the current CUDA stream
        const auto& workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32));
        DG_CUDA_RUNTIME_CHECK(cudaMemsetAsync(workspace.data_ptr(), 0, workspace.nbytes(),
                                              c10::cuda::getCurrentCUDAStream()));
        bmk_bnk_mn(a, b, workspace, workspace);

        // This line has an implicit FP32-to-BF16 casting
        d.copy_(workspace);
        return;
    }

    DG_HOST_ASSERT(a.is_contiguous());
    DG_HOST_ASSERT(b.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());

    const auto& [s , m, k ] = get_shape<3>(a);
    const auto& [s_, n, k_] = get_shape<3>(b);
    DG_HOST_ASSERT(s == s_ and k == k_);

    // Dispatch implementation
    const auto& arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
    } else if (arch_major == 10) {
        sm100_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }
}
|
||||
|
||||
// Einsum `bhr,hdr->bhd` on BF16 tensors, forwarded to `cublaslt_bhr_hdr_bhd`.
// `A` is `[b, h, r]`, `B` is `[h, d, r]` and `D` is `[b, h, d]`; the shared
// dimensions must match and every tensor must have a unit stride on its last
// dimension.
static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D) {
|
||||
const auto& [b , h , r ] = get_shape<3>(A);
|
||||
const auto& [h_, d , r_] = get_shape<3>(B);
|
||||
const auto& [b_, h__, d_] = get_shape<3>(D);
|
||||
// `h` appears in all three tensors, hence the two separate comparisons
DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);
|
||||
|
||||
// Only the innermost dimension is required to be dense
DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
|
||||
DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1);
|
||||
DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);
|
||||
|
||||
cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d);
|
||||
}
|
||||
|
||||
// Einsum `bhd,hdr->bhr` on BF16 tensors, forwarded to `cublaslt_bhd_hdr_bhr`.
// `A` is `[b, h, d]`, `B` is `[h, d, r]` and `D` is `[b, h, r]`; the shared
// dimensions must match and every tensor must have a unit stride on its last
// dimension.
static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D) {
|
||||
const auto& [b , h , d ] = get_shape<3>(A);
|
||||
const auto& [h_, d_ , r ] = get_shape<3>(B);
|
||||
const auto& [b_, h__, r_] = get_shape<3>(D);
|
||||
// `h` appears in all three tensors, hence the two separate comparisons
DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);
|
||||
|
||||
// Only the innermost dimension is required to be dense
DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
|
||||
DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1);
|
||||
DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);
|
||||
|
||||
cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d);
|
||||
}
|
||||
|
||||
static void einsum(const std::string& expr,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c) {
|
||||
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
|
||||
if (c.has_value()) {
|
||||
DG_HOST_ASSERT(c->scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
|
||||
}
|
||||
|
||||
// Some hardcoded Einstein sum kernels
|
||||
// TODO: support any expression
|
||||
// TODO: canonicalize expression
|
||||
if (expr == "bmk,bnk->mn") {
|
||||
bmk_bnk_mn(a, b, d, c);
|
||||
} else if (expr == "bhr,hdr->bhd") {
|
||||
DG_HOST_ASSERT(not c.has_value());
|
||||
bhr_hdr_bhd(a, b, d);
|
||||
} else if (expr == "bhd,hdr->bhr") {
|
||||
DG_HOST_ASSERT(not c.has_value());
|
||||
bhd_hdr_bhr(a, b, d);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr));
|
||||
}
|
||||
}
|
||||
|
||||
// Registers the einsum API on the given pybind11 module.
// `einsum` takes `expr`, `a`, `b` and `d` as required arguments; `c` defaults
// to `std::nullopt`.
static void register_apis(pybind11::module_& m) {
|
||||
m.def("einsum", &einsum,
|
||||
py::arg("expr"), py::arg("a"), py::arg("b"),
|
||||
py::arg("d"), py::arg("c") = std::nullopt);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm::einsum
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
|
||||
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
@@ -52,13 +53,18 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128));
|
||||
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch into different implements
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
if (std::get<1>(recipe.value()) == 1) {
|
||||
sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
}
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
||||
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
|
||||
@@ -261,6 +267,60 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torc
|
||||
}
|
||||
}
|
||||
|
||||
// K-grouped contiguous FP8 GEMM (NT layout), 1D1D kernel only.
//
// Arguments:
//   - `a`, `b`: pairs of (FP8 data, scaling factors); the data tensors must be
//     contiguous and hold `m * sum(ks)` and `n * sum(ks)` elements respectively
//   - `d`: contiguous output of shape `[num_groups, m, n]`
//   - `ks` / `ks_tensor`: per-group K sizes, passed both as a host vector and
//     as a tensor to the layout transform and the kernel
//   - `c`: optional contiguous FP32 tensor
//   - `recipe`: must be `(1, 1, 128)`
//   - `compiled_dims`: forwarded to the kernel launcher
//
// Returns without launching anything when `sum(ks) == 0`. Only SM90 is
// supported; other architectures hit `DG_HOST_UNREACHABLE`.
static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                             const std::pair<torch::Tensor, torch::Tensor>& b,
                                             const torch::Tensor& d,
                                             const std::vector<int>& ks,
                                             const torch::Tensor& ks_tensor,
                                             const std::optional<torch::Tensor>& c,
                                             const std::tuple<int, int, int>& recipe,
                                             const std::string& compiled_dims) {
    // Must be 1D1D kernel
    DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));

    // Shape checks
    // NOTES: the sum and the products are evaluated in the element-count type of
    // `numel()`, so that `m * sum_k` and `n * sum_k` cannot overflow a narrower
    // integer before the comparison
    using numel_t = decltype(a.first.numel());
    const auto& [num_groups, m, n] = get_shape<3>(d);
    numel_t sum_k = 0;
    for (const auto& k: ks)
        sum_k += k;
    DG_HOST_ASSERT(a.first.numel() == static_cast<numel_t>(m) * sum_k);
    DG_HOST_ASSERT(b.first.numel() == static_cast<numel_t>(n) * sum_k);

    // Contiguity checks
    DG_HOST_ASSERT(a.first.is_contiguous());
    DG_HOST_ASSERT(b.first.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());
    if (c.has_value()) {
        DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(c.value().is_contiguous());
    }

    // Do nothing if empty
    if (sum_k == 0)
        return;

    // Transform SF with padding
    const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
    const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);

    // Allocate tensormap buffer
    // `4` means the double buffering for both A and B operands (2 * 2)
    const auto& num_sms = device_runtime->get_num_sms();
    const auto& tensor_map_buffer = torch::empty({num_sms * 4 * static_cast<int>(sizeof(CUtensorMap))},
                                                 a.first.options().dtype(torch::kByte));

    // Dispatch implementation
    const auto& arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer,
                                     cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }
}
|
||||
|
||||
static void bf16_gemm_nt(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
@@ -403,6 +463,43 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T
|
||||
}
|
||||
}
|
||||
|
||||
// cuBLASLt-backed GEMM of shape `[M, K] @ [N, K].T` writing into `d` (`[M, N]`).
// When `c` is given, it must have the same dtype as `d`. Nothing is launched
// for problems with `m == 0` or `n == 0`.
static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    // Majorness of both operands is resolved up-front and handed to the launcher
    const auto& major_a = get_major_type_ab(a);
    const auto& major_b = get_major_type_ab(b);

    // All three tensors must agree on M, N and K
    const auto& [m , k ] = get_shape<2>(a);
    const auto& [n , k_] = get_shape<2>(b);
    const auto& [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);

    // The optional `c` must match the output dtype
    if (c.has_value()) {
        DG_HOST_ASSERT(c.value().scalar_type() == d.scalar_type());
    }

    // Launch only for non-empty problems
    const bool is_empty = (m == 0) or (n == 0);
    if (not is_empty)
        cublaslt_gemm(a, b, c, d, m, n, k, major_a, major_b);
}
|
||||
|
||||
// NN variant: expressed through the NT entry point by passing `b` with its two
// dimensions swapped.
static void cublaslt_gemm_nn(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    const auto& b_transposed = b.transpose(0, 1);
    cublaslt_gemm_nt(a, b_transposed, d, c);
}
|
||||
|
||||
// TN variant: expressed through the NT entry point by passing both `a` and `b`
// with their two dimensions swapped.
static void cublaslt_gemm_tn(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    const auto& a_transposed = a.transpose(0, 1);
    const auto& b_transposed = b.transpose(0, 1);
    cublaslt_gemm_nt(a_transposed, b_transposed, d, c);
}
|
||||
|
||||
// TT variant: expressed through the NT entry point by passing `a` with its two
// dimensions swapped; `b` is forwarded untouched.
static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    const auto& a_transposed = a.transpose(0, 1);
    cublaslt_gemm_nt(a_transposed, b, d, c);
}
|
||||
|
||||
static void register_apis(pybind11::module_& m) {
|
||||
// FP8 GEMMs
|
||||
m.def("fp8_gemm_nt", &fp8_gemm_nt,
|
||||
@@ -442,6 +539,11 @@ static void register_apis(pybind11::module_& m) {
|
||||
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
|
||||
py::arg("recipe") = std::make_tuple(1, 1, 128),
|
||||
py::arg("compiled_dims") = "mn");
|
||||
m.def("k_grouped_fp8_gemm_nt_contiguous", &k_grouped_fp8_gemm_nt_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
|
||||
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
|
||||
py::arg("recipe") = std::make_tuple(1, 1, 128),
|
||||
py::arg("compiled_dims") = "mn");
|
||||
|
||||
// BF16 GEMMs
|
||||
m.def("bf16_gemm_nt", &bf16_gemm_nt,
|
||||
@@ -466,6 +568,16 @@ static void register_apis(pybind11::module_& m) {
|
||||
m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
|
||||
py::arg("expected_m"), py::arg("compiled_dims") = "nk");
|
||||
|
||||
// cuBLASLt GEMMs
|
||||
m.def("cublaslt_gemm_nt", &cublaslt_gemm_nt,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
|
||||
m.def("cublaslt_gemm_nn", &cublaslt_gemm_nn,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
|
||||
m.def("cublaslt_gemm_tn", &cublaslt_gemm_tn,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
|
||||
m.def("cublaslt_gemm_tt", &cublaslt_gemm_tt,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm::gemm
|
||||
|
||||
@@ -56,14 +56,14 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
|
||||
|
||||
// FP32 on SM90
|
||||
if (sf.scalar_type() == torch::kFloat and arch_major == 9)
|
||||
DG_HOST_UNREACHABLE("Unimplemented");
|
||||
return get_mn_major_tma_aligned_tensor(sf);
|
||||
|
||||
// FP32 on SM100
|
||||
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
|
||||
return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks);
|
||||
|
||||
// INT on SM100
|
||||
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
|
||||
if (sf.scalar_type() == torch::kInt and arch_major == 10)
|
||||
DG_HOST_UNREACHABLE("Unimplemented");
|
||||
|
||||
DG_HOST_UNREACHABLE("Unknown cases");
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
// GEMM kernels
|
||||
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
|
||||
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
|
||||
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
|
||||
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
|
||||
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
|
||||
|
||||
// Einsum kernels
|
||||
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
|
||||
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
|
||||
|
||||
// Layout kernels
|
||||
#include <deep_gemm/impls/smxx_layout.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "../utils/exception.hpp"
|
||||
@@ -11,8 +12,28 @@ class DeviceRuntime {
|
||||
int num_sms = 0, tc_util = 0;
|
||||
std::shared_ptr<cudaDeviceProp> cached_prop;
|
||||
|
||||
// cuBLASLt utils
|
||||
static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
|
||||
cublasLtHandle_t cublaslt_handle{};
|
||||
std::shared_ptr<torch::Tensor> cublaslt_workspace;
|
||||
|
||||
public:
|
||||
explicit DeviceRuntime() = default;
|
||||
// Allocates the cuBLASLt workspace (`kCublasLtWorkspaceSize` bytes, as a byte
// tensor on the CUDA device) and creates the cuBLASLt handle.
explicit DeviceRuntime() {
|
||||
cublaslt_workspace = std::make_shared<torch::Tensor>(torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)));
|
||||
DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle));
|
||||
}
|
||||
|
||||
// Destroys the cuBLASLt handle created in the constructor.
// NOTE(review): declared `noexcept(false)`, presumably because
// `DG_CUBLASLT_CHECK` may throw on failure — confirm against the macro.
~DeviceRuntime() noexcept(false) {
|
||||
DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle));
|
||||
}
|
||||
|
||||
// Returns the cuBLASLt handle owned by this runtime (not a new handle).
cublasLtHandle_t get_cublaslt_handle() const {
|
||||
return cublaslt_handle;
|
||||
}
|
||||
|
||||
// Returns the cuBLASLt workspace tensor by value (a copy of the tensor object
// held through `cublaslt_workspace`).
torch::Tensor get_cublaslt_workspace() const {
|
||||
return *cublaslt_workspace;
|
||||
}
|
||||
|
||||
std::shared_ptr<cudaDeviceProp> get_prop() {
|
||||
if (cached_prop == nullptr)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <deep_gemm/common/types.hpp>
|
||||
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../../utils/layout.hpp"
|
||||
|
||||
@@ -80,18 +82,19 @@ static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
|
||||
return divisible and num_sms % num_multicast == 0;
|
||||
}
|
||||
|
||||
static int get_swizzle_mode(const int& block_size, const int& elem_size) {
|
||||
template <typename size_type_t>
|
||||
static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) {
|
||||
// `> 0` means interleaving
|
||||
// 16B actually means non-swizzling (but interleaving)
|
||||
for (const int& mode: {128, 64, 32, 16}) {
|
||||
if ((block_size * elem_size) % mode == 0)
|
||||
if ((block_size * static_cast<int>(elem_size)) % mode == 0)
|
||||
return mode;
|
||||
}
|
||||
DG_HOST_UNREACHABLE("Unreachable");
|
||||
}
|
||||
|
||||
template <typename ArchSpec>
|
||||
static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
|
||||
static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const KernelType& kernel_type,
|
||||
const int& m, const int& n, const int& k,
|
||||
const int& block_m, const int& block_n, const int& block_k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
@@ -104,7 +107,7 @@ static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
|
||||
const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n);
|
||||
const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size);
|
||||
const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size);
|
||||
const int& swizzle_cd_mode = get_swizzle_mode(block_n, cd_elem_size);
|
||||
const int& swizzle_cd_mode = ArchSpec::enable_cd_swizzle(cd_dtype) ? get_swizzle_mode(block_n, cd_elem_size) : 0;
|
||||
|
||||
// Different archs have different epilogue pipelines
|
||||
const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype);
|
||||
@@ -121,9 +124,11 @@ static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
|
||||
// M-barriers and tensor memory pointers
|
||||
const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages);
|
||||
const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size();
|
||||
const int& smem_tensor_map = ArchSpec::get_tensormap_smem_size(gemm_type);
|
||||
|
||||
// Sum them up
|
||||
int smem_size = 0;
|
||||
smem_size += smem_tensor_map;
|
||||
smem_size += smem_cd;
|
||||
smem_size += num_stages * smem_a_per_stage;
|
||||
smem_size += num_stages * smem_b_per_stage;
|
||||
@@ -151,15 +156,12 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
|
||||
|
||||
// Select M/N block sizes
|
||||
// TODO: support `% 16 == 8` block size on SM90
|
||||
auto block_ms = std::vector{64, 128, 256};
|
||||
if (gemm_type == GemmType::MGroupedContiguous)
|
||||
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
|
||||
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
|
||||
block_ms = std::vector{64, 128};
|
||||
std::vector<int> block_ns;
|
||||
for (int i = 16; i <= 256; i += 16)
|
||||
block_ns.push_back(i);
|
||||
const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype);
|
||||
|
||||
// K block size is selected in a fixed manner
|
||||
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
|
||||
@@ -214,9 +216,9 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0);
|
||||
|
||||
// Decide the number of TMA multicasts and whether broadcast on A
|
||||
MulticastConfig best_multicast_config = {1, true};
|
||||
MulticastConfig best_multicast_config = {1, false};
|
||||
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
|
||||
gemm_type, m, n, best_block_m, best_block_n, num_sms);
|
||||
gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms);
|
||||
const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
|
||||
bool order[2] = {false, true};
|
||||
if (best_block_m > best_block_n)
|
||||
@@ -232,11 +234,11 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
constexpr int smem_capacity = ArchSpec::smem_capacity;
|
||||
int best_num_stages = 0;
|
||||
SharedMemoryConfig best_smem_config;
|
||||
for (int num_stages = std::min(12, ceil_div(k, block_k)); num_stages > 0; -- num_stages) {
|
||||
for (int num_stages = 12; num_stages > 0; -- num_stages) {
|
||||
if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
|
||||
continue;
|
||||
|
||||
best_smem_config = get_smem_config<ArchSpec>(kernel_type,
|
||||
best_smem_config = get_smem_config<ArchSpec>(gemm_type, kernel_type,
|
||||
m, n, k,
|
||||
best_block_m, best_block_n, block_k,
|
||||
major_a, major_b,
|
||||
|
||||
@@ -12,6 +12,15 @@ namespace deep_gemm {
|
||||
struct SM100ArchSpec {
|
||||
static constexpr int smem_capacity = 232448;
|
||||
|
||||
// Candidate N block sizes on SM100: 16, followed by every multiple of 32 up to
// 256. `cd_dtype` does not affect the result.
static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
    std::vector<int> candidates;

    // 16 is for better SM usage
    candidates.push_back(16);

    // Stride 32 is due to low-performance swizzle-16/32B
    for (int multiple = 1; multiple * 32 <= 256; ++ multiple)
        candidates.push_back(multiple * 32);
    return candidates;
}
|
||||
|
||||
static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) {
|
||||
return block_m / (config.is_multicast_on_a ? config.num_multicast : 1);
|
||||
}
|
||||
@@ -29,6 +38,10 @@ struct SM100ArchSpec {
|
||||
return block_n;
|
||||
}
|
||||
|
||||
// Whether C/D swizzling is enabled: always `true` here, `cd_dtype` is not consulted.
static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
|
||||
const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) {
|
||||
constexpr int num_utccp_aligned_elems = 128;
|
||||
@@ -86,7 +99,7 @@ struct SM100ArchSpec {
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
|
||||
const int& m, const int& n, const int& block_m, const int& block_n,
|
||||
const int& num_sms) {
|
||||
// TODO: support other layouts
|
||||
@@ -138,12 +151,17 @@ struct SM100ArchSpec {
|
||||
// TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers
|
||||
// NOTES: 1D2D kernel will not use the with-SF full barriers
|
||||
// NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
|
||||
return num_stages * 8 * 3 + 2 * 8 * 2;
|
||||
// NOTES: the last barrier is for tensor core utilization control
|
||||
return num_stages * 8 * 3 + 2 * 8 * 2 + 8;
|
||||
}
|
||||
|
||||
// Shared memory size reserved for the tensor memory pointer: always 4
// (presumably in bytes — confirm against the kernel).
static int get_tmem_ptr_smem_size() {
|
||||
return 4;
|
||||
}
|
||||
|
||||
// Shared memory size reserved for tensor maps: always 0 here, `gemm_type` is not consulted.
static int get_tensormap_smem_size(const GemmType& gemm_type) {
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
@@ -11,6 +11,15 @@ namespace deep_gemm {
|
||||
struct SM90ArchSpec {
|
||||
static constexpr int smem_capacity = 232448;
|
||||
|
||||
// Candidate N block sizes on SM90: a stride-16 sequence up to 256, which starts
// at 8 for FP32 output and at 16 otherwise.
static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
    std::vector<int> candidates;

    // Avoid bank conflicts for FP32 output
    int block_n = (cd_dtype == torch::kFloat) ? 8 : 16;
    while (block_n <= 256) {
        candidates.push_back(block_n);
        block_n += 16;
    }
    return candidates;
}
|
||||
|
||||
// M block size loaded for the A/B operands: `block_m` unchanged,
// `multicast_config` is not consulted.
static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) {
|
||||
return block_m;
|
||||
}
|
||||
@@ -19,26 +28,35 @@ struct SM90ArchSpec {
|
||||
return block_n;
|
||||
}
|
||||
|
||||
static int get_cd_store_block_m(const int& block_m) {
|
||||
return block_m;
|
||||
static int get_cd_store_block_m(const int& block_m, const bool& single_warpgroup_sync = false) {
|
||||
constexpr int wgmma_m = 64;
|
||||
return single_warpgroup_sync ? wgmma_m : block_m;
|
||||
}
|
||||
|
||||
// N block size used for storing C/D: `block_n` unchanged.
static int get_cd_store_block_n(const int& block_n) {
|
||||
return block_n;
|
||||
}
|
||||
|
||||
// Whether C/D swizzling is enabled: `false` for FP32 output, `true` for every
// other dtype.
static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
|
||||
return cd_dtype != torch::kFloat;
|
||||
}
|
||||
|
||||
static bool is_block_size_legal(const KernelType& kernel_type,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
// FP32 output does not support `block_m == 256`
|
||||
// SM90 FP32 output does not support `block_m == 256`
|
||||
if (cd_dtype == at::kFloat and block_m == 256)
|
||||
return false;
|
||||
|
||||
// TODO: more general block N selection
|
||||
// Must be some fixed block N selections
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 and block_n != 152))
|
||||
return false;
|
||||
// Avoid large C/D shared memory for FP32 output
|
||||
// Ensure `num_stages >= 4` (for 1D1D Kernel), `num_stages >= 3` (for No SF kernel)
|
||||
if (block_n > 128 and cd_dtype == torch::kFloat) {
|
||||
if (kernel_type == KernelType::Kernel1D1D and block_n > 152)
|
||||
return false;
|
||||
if (kernel_type == KernelType::KernelNoSF and block_n > 200)
|
||||
return false;
|
||||
}
|
||||
|
||||
// Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
|
||||
// Or too many register spills
|
||||
@@ -66,9 +84,13 @@ struct SM90ArchSpec {
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
|
||||
const int& m, const int& n, const int& block_m, const int& block_n,
|
||||
const int& num_sms) {
|
||||
// Disable multicast when the number of k-groups is large (a heuristic)
|
||||
if (gemm_type == GemmType::KGroupedContiguous and num_groups > 4)
|
||||
return {false, false};
|
||||
|
||||
return {
|
||||
is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
|
||||
// For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
|
||||
@@ -96,9 +118,10 @@ struct SM90ArchSpec {
|
||||
|
||||
int smem_sfa_per_stage = block_m * static_cast<int>(sizeof(float));
|
||||
int smem_sfb_per_stage = 0;
|
||||
// TODO: figure out here
|
||||
if (kernel_type == KernelType::Kernel1D1D)
|
||||
smem_sfb_per_stage = align(block_n * 4, block_k);
|
||||
if (kernel_type == KernelType::Kernel1D1D) {
|
||||
// NOTES: `128` is for 2D TMA alignment requirement
|
||||
smem_sfb_per_stage = align(block_n * 4, 128);
|
||||
}
|
||||
return {smem_sfa_per_stage, smem_sfb_per_stage};
|
||||
}
|
||||
|
||||
@@ -109,13 +132,16 @@ struct SM90ArchSpec {
|
||||
}
|
||||
|
||||
static int get_barrier_smem_size(const int& num_stages) {
|
||||
// For 1D1D kernels, there is an extra barrier for accumulation
|
||||
return (num_stages + 1) * 8 * 2;
|
||||
return num_stages * 8 * 2;
|
||||
}
|
||||
|
||||
// Shared memory size reserved for the tensor memory pointer: always 0 here.
static int get_tmem_ptr_smem_size() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int get_tensormap_smem_size(const GemmType& gemm_type) {
|
||||
return gemm_type == GemmType::KGroupedContiguous ? 4 * static_cast<int>(sizeof(CUtensorMap)) : 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
12
csrc/jit_kernels/impls/epilogue.hpp
Normal file
12
csrc/jit_kernels/impls/epilogue.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// Resolves the epilogue type name: the given one when present (even if it is
// an empty string), "EpilogueIdentity" otherwise.
static std::string get_default_epilogue_type(const std::optional<std::string>& epilogue_type) {
    if (epilogue_type.has_value())
        return epilogue_type.value();
    return "EpilogueIdentity";
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -4,6 +4,8 @@
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "../../utils/system.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
@@ -51,7 +53,11 @@ static std::string to_string(const at::ScalarType& dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype) {
|
||||
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype,
|
||||
const bool& allow_tf32) {
|
||||
if (allow_tf32 and dtype == torch::kFloat)
|
||||
return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32;
|
||||
|
||||
switch (dtype) {
|
||||
case torch::kInt: return CU_TENSOR_MAP_DATA_TYPE_INT32;
|
||||
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
@@ -61,9 +67,14 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
|
||||
}
|
||||
}
|
||||
|
||||
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode) {
|
||||
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) {
|
||||
if (base != 0) {
|
||||
DG_HOST_ASSERT(base == 32 and mode == 128);
|
||||
return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
|
||||
}
|
||||
|
||||
switch (mode) {
|
||||
case 0: return CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
case 0:
|
||||
case 16: return CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
case 32: return CU_TENSOR_MAP_SWIZZLE_32B;
|
||||
case 64: return CU_TENSOR_MAP_SWIZZLE_64B;
|
||||
@@ -76,7 +87,8 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
|
||||
int gmem_inner_dim, int gmem_outer_dim,
|
||||
int smem_inner_dim, int smem_outer_dim,
|
||||
const int& gmem_outer_stride,
|
||||
const int& swizzle_mode) {
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
const auto& elem_size = static_cast<int>(t.element_size());
|
||||
if (swizzle_mode != 0)
|
||||
smem_inner_dim = swizzle_mode / elem_size;
|
||||
@@ -87,14 +99,42 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
|
||||
const cuuint64_t gmem_strides[1] = {static_cast<cuuint64_t>(gmem_outer_stride * elem_size), };
|
||||
const cuuint32_t elem_strides[2] = {1, 1};
|
||||
if (get_env<int>("DG_JIT_DEBUG")) {
|
||||
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d, elem size: %d\n",
|
||||
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d\n",
|
||||
gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim,
|
||||
gmem_outer_stride, swizzle_mode, elem_size);
|
||||
gmem_outer_stride, swizzle_mode, swizzle_base, elem_size);
|
||||
}
|
||||
DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
|
||||
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type()),
|
||||
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
|
||||
2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode),
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
|
||||
CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
|
||||
return tensor_map;
|
||||
}
|
||||
|
||||
// Builds a 3D tiled TMA descriptor (`CUtensorMap`) for tensor `t`.
//
// Parameters:
//   gmem_dim_0..2     global tensor extents, in elements, passed to the driver in this order
//   smem_dim_0..2     box (shared memory tile) extents, in elements, in the same order
//   gmem_stride_0..1  global strides in elements; converted to bytes here
//   swizzle_mode      swizzle width in bytes (0 disables swizzling); when non-zero,
//                     `smem_dim_0` must equal `swizzle_mode / element size`
//   swizzle_base      forwarded to `mode_into_tensor_map_swizzle`
//   allow_tf32        forwarded to `aten_dtype_to_tensor_map_dtype`
//
// Returns the encoded descriptor; a driver failure is reported through `DG_CUDA_DRIVER_CHECK`.
static CUtensorMap make_tma_3d_desc(const torch::Tensor& t,
                                    const int& gmem_dim_0, const int& gmem_dim_1, const int& gmem_dim_2,
                                    const int& smem_dim_0, const int& smem_dim_1, const int& smem_dim_2,
                                    const int& gmem_stride_0, const int& gmem_stride_1,
                                    const int& swizzle_mode, const int& swizzle_base = 0,
                                    const bool& allow_tf32 = false) {
    const auto elem_size = static_cast<int>(t.element_size());
    if (swizzle_mode != 0) {
        DG_HOST_ASSERT(smem_dim_0 == swizzle_mode / elem_size);
    }

    CUtensorMap tensor_map;
    const cuuint64_t gmem_dims[3] = {static_cast<cuuint64_t>(gmem_dim_0),
                                     static_cast<cuuint64_t>(gmem_dim_1),
                                     static_cast<cuuint64_t>(gmem_dim_2)};
    const cuuint32_t smem_dims[3] = {static_cast<cuuint32_t>(smem_dim_0),
                                     static_cast<cuuint32_t>(smem_dim_1),
                                     static_cast<cuuint32_t>(smem_dim_2)};
    // NOTES: widen before multiplying, as `stride * elem_size` may not fit into a 32-bit `int`
    const cuuint64_t gmem_strides[2] = {static_cast<cuuint64_t>(gmem_stride_0) * static_cast<cuuint64_t>(elem_size),
                                        static_cast<cuuint64_t>(gmem_stride_1) * static_cast<cuuint64_t>(elem_size)};
    const cuuint32_t elem_strides[3] = {1, 1, 1};
    if (get_env<int>("DG_JIT_DEBUG")) {
        printf("Making 3D TMA desc: global memory: %d %d %d, shared memory: %d %d %d, outer stride: %d %d, swizzle: %d (base: %d), elem size: %d\n",
               gmem_dim_0, gmem_dim_1, gmem_dim_2, smem_dim_0, smem_dim_1, smem_dim_2,
               gmem_stride_0, gmem_stride_1, swizzle_mode, swizzle_base, elem_size);
    }
    DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
        &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
        3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
        CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
        CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
    return tensor_map;
}
|
||||
@@ -105,7 +145,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
|
||||
const int& block_m, const int& block_k,
|
||||
const int& outer_stride,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
if (num_groups > 1)
|
||||
DG_HOST_ASSERT(major == cute::UMMA::Major::K);
|
||||
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
|
||||
@@ -114,7 +155,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
|
||||
gmem_inner_dim, gmem_outer_dim,
|
||||
smem_inner_dim, smem_outer_dim,
|
||||
outer_stride,
|
||||
swizzle_mode);
|
||||
swizzle_mode, swizzle_base,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
|
||||
@@ -123,7 +165,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
|
||||
const int& block_n, const int& block_k,
|
||||
const int& outer_stride,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n);
|
||||
const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n);
|
||||
|
||||
@@ -132,7 +175,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
|
||||
gmem_inner_dim, gmem_outer_dim * num_groups,
|
||||
smem_inner_dim, smem_outer_dim,
|
||||
outer_stride,
|
||||
swizzle_mode);
|
||||
swizzle_mode, swizzle_base,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
|
||||
@@ -140,15 +184,16 @@ static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
|
||||
const int& block_m, const int& block_n,
|
||||
const int& outer_stride,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
// Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode`
|
||||
// bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
|
||||
return make_tma_2d_desc(t,
|
||||
shape_n, shape_m * num_groups,
|
||||
block_n, block_m,
|
||||
outer_stride,
|
||||
swizzle_mode);
|
||||
swizzle_mode, swizzle_base,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
|
||||
@@ -156,7 +201,8 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
|
||||
int shape_mn, int shape_k,
|
||||
const int& block_mn, const int& block_k,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
DG_HOST_ASSERT(major == cute::UMMA::Major::MN);
|
||||
|
||||
// TODO: maybe swizzle SF as well
|
||||
@@ -167,7 +213,8 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
|
||||
shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups,
|
||||
block_mn, 1,
|
||||
shape_mn,
|
||||
swizzle_mode);
|
||||
swizzle_mode, swizzle_base,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
@@ -42,7 +42,7 @@ static void __instantiate_kernel() {{
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
@@ -56,7 +56,7 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.num_groups,
|
||||
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.num_stages,
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms,
|
||||
@@ -80,8 +80,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
// TODO: test other Ks
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
const auto& aligned_k = align(k, 64);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Normal, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
@@ -122,7 +121,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
|
||||
|
||||
// Launch
|
||||
const SM100BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
|
||||
137
csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp
Normal file
137
csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp
Normal file
@@ -0,0 +1,137 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime description for the SM100 `sm100_bmn_bnk_mn_gemm_impl` kernel.
// `generate_impl` produces the source that instantiates the kernel template and
// `launch_impl` launches a built kernel; both are static hooks taking `Args`.
// NOTE(review): presumably these hooks are invoked by the `LaunchRuntime` CRTP base - confirm there.
class SM100BmkBnkMnRuntime final: public LaunchRuntime<SM100BmkBnkMnRuntime> {
|
||||
public:
|
||||
// Everything needed to generate and launch one kernel instance.
// `m, n, k`, the block sizes, `split_factor`, the swizzle modes, `num_stages` and `num_threads`
// are formatted into the generated source as template arguments; `s` and the three tensor maps
// are passed as kernel arguments at launch. `launch_args` is not read in this class.
struct Args {
|
||||
int s, m, n, k;
|
||||
int block_m, block_n, block_k;
|
||||
int split_factor;
|
||||
int swizzle_ab_mode, swizzle_cd_mode;
|
||||
int num_stages;
|
||||
int num_threads;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
|
||||
CUtensorMap tensor_map_a;
|
||||
CUtensorMap tensor_map_b;
|
||||
CUtensorMap tensor_map_d;
|
||||
};
|
||||
|
||||
// Returns the source text of a translation unit that takes the address of one
// `sm100_bmn_bnk_mn_gemm_impl<...>` instantiation. The template has 11 `{}` placeholders, filled
// in order by the 11 arguments below; `{{` / `}}` are escaped literal braces for `fmt::format`.
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm100_bmn_bnk_mn_gemm_impl<
|
||||
{}, {}, {},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
|
||||
)",
|
||||
args.m, args.n, args.k,
|
||||
args.block_m, args.block_n, args.block_k,
|
||||
args.split_factor,
|
||||
args.swizzle_ab_mode, args.swizzle_cd_mode,
|
||||
args.num_stages, args.num_threads);
|
||||
}
|
||||
|
||||
// Launches `kernel` with `config`, passing `s` and the A/B/D tensor maps as kernel arguments;
// the launch result is checked by `DG_CUDA_UNIFIED_CHECK`.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.s, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Host-side launcher for the SM100 `sm100_bmn_bnk_mn_gemm_impl` kernel.
//
// `a` is described to TMA as a `[s * m, k]` matrix, `b` as `[s * n, k]` and `d` as `[m, n]`,
// each with `k` (resp. `n`) as the contiguous inner dimension.
// NOTE(review): presumably this computes `bmk,bnk->mn`, reducing over both the batch `s` and `k`
// with split-K across thread blocks - confirm against the device kernel.
//
// Requirements (asserted): `s, m, n, k` are positive, `k` is a multiple of `block_k` (64),
// `m` and `n` are multiples of 64, and `s * max(m, n)` fits into an `int`.
static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a,
                                  const torch::Tensor &b,
                                  const torch::Tensor &d,
                                  const int &s, const int &m, const int &n, const int &k) {
    constexpr int block_m = 128;
    constexpr int block_n = 128;
    constexpr int block_k = 64;
    constexpr int num_threads = 128;

    // NOTES: empty problems must be rejected here, otherwise the split factor and grid size
    // computations below divide by zero
    DG_HOST_ASSERT(s > 0 and m > 0 and n > 0 and k > 0);
    DG_HOST_ASSERT(k % block_k == 0);
    DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0);
    // The TMA descriptors take `s * m` and `s * n` as `int` extents
    DG_HOST_ASSERT(static_cast<int64_t>(s) * static_cast<int64_t>(std::max(m, n)) <= std::numeric_limits<int>::max());

    const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast<int>(a.element_size()));
    const int swizzle_cd_mode = get_swizzle_mode(block_n, static_cast<int>(d.element_size()));

    // Get best config: each output tile gets roughly `num_sms / num_mn_blocks` thread blocks,
    // and every thread block handles `split_factor` of the `s * (k / block_k)` reduction blocks
    const int num_sms = device_runtime->get_num_sms();
    const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);
    const int num_sk_blocks = s * (k / block_k);
    const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1));

    // Stage-independent shared memory usages
    const int smem_cd = block_m * swizzle_cd_mode * 2;
    constexpr int smem_a_per_stage = block_m * block_k * static_cast<int>(sizeof(cutlass::bfloat16_t));
    constexpr int smem_b_per_stage = block_n * block_k * static_cast<int>(sizeof(cutlass::bfloat16_t));
    const int smem_tmem_ptr = SM100ArchSpec::get_tmem_ptr_smem_size();

    // Select best number of stages
    // NOTES: we select 4 as start, as it is tested to be faster than values > 4
    int num_stages = 4, smem_size = 0;
    while (true) {
        smem_size = smem_cd
                  + (smem_a_per_stage + smem_b_per_stage) * num_stages
                  + SM100ArchSpec::get_barrier_smem_size(num_stages)
                  + smem_tmem_ptr;
        if (smem_size <= SM100ArchSpec::smem_capacity)
            break;

        -- num_stages;
    }
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    if (get_env("DG_JIT_DEBUG", 0)) {
        printf("S: %d, M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, split-K factor: %d, "
               "stages: %d, shared memory: %d, swizzle AB: %d, swizzle CD: %d\n",
               s, m, n, k, block_m, block_n, block_k, split_factor,
               num_stages, smem_size, swizzle_ab_mode, swizzle_cd_mode);
    }

    const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
    const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);
    const auto& tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode);

    // Launch: one thread block per (output tile, reduction split) pair
    const SM100BmkBnkMnRuntime::Args& args = {
        .s = s, .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .split_factor = split_factor,
        .swizzle_ab_mode = swizzle_ab_mode,
        .swizzle_cd_mode = swizzle_cd_mode,
        .num_stages = num_stages,
        .num_threads = num_threads,
        .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_threads, smem_size),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d
    };
    const auto& code = SM100BmkBnkMnRuntime::generate(args);
    const auto& runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code);
    SM100BmkBnkMnRuntime::launch(runtime, args);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -9,6 +9,8 @@
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
|
||||
#include "epilogue.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
@@ -18,6 +20,7 @@ public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
const std::optional<std::string>& epilogue_type;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
@@ -44,11 +47,12 @@ static void __instantiate_kernel() {{
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}, {}
|
||||
{}, {}, {},
|
||||
{}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -57,11 +61,12 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.num_groups,
|
||||
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.num_stages,
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms,
|
||||
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype));
|
||||
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype),
|
||||
get_default_epilogue_type(args.epilogue_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -80,7 +85,8 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const std::string& compiled_dims,
|
||||
const std::optional<std::string>& epilogue_type = std::nullopt) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Normal, KernelType::Kernel1D1D,
|
||||
@@ -99,7 +105,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
@@ -129,6 +135,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = epilogue_type,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -186,6 +193,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -244,6 +252,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -324,6 +333,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
|
||||
.m = m, .n = n, .k = sum_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
|
||||
@@ -18,6 +18,7 @@ public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
const std::optional<std::string>& epilogue_type;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
@@ -46,7 +47,8 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}
|
||||
{}, {},
|
||||
{}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -59,7 +61,8 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms,
|
||||
to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype));
|
||||
to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype),
|
||||
get_default_epilogue_type(args.epilogue_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -78,7 +81,8 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const std::string& compiled_dims,
|
||||
const std::optional<std::string>& epilogue_type = std::nullopt) {
|
||||
DG_HOST_ASSERT(not c.has_value());
|
||||
|
||||
const auto& aligned_k = align(k, 128);
|
||||
@@ -98,7 +102,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
@@ -111,6 +115,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = epilogue_type,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -164,6 +169,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -218,6 +224,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const t
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
|
||||
@@ -41,7 +41,7 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {}
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -53,7 +53,8 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
|
||||
to_string(args.gemm_config.cd_dtype));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -73,10 +74,10 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(not c.has_value());
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
|
||||
const auto& aligned_k = align(k, 64);
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::Normal, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
@@ -102,7 +103,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
|
||||
|
||||
// Launch
|
||||
const SM90BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
|
||||
131
csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp
Normal file
131
csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp
Normal file
@@ -0,0 +1,131 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime description for the SM90 `sm90_bmn_bnk_mn_gemm_impl` kernel.
// `generate_impl` produces the source that instantiates the kernel template and
// `launch_impl` launches a built kernel; both are static hooks taking `Args`.
// NOTE(review): presumably these hooks are invoked by the `LaunchRuntime` CRTP base - confirm there.
class SM90BmkBnkMnRuntime final: public LaunchRuntime<SM90BmkBnkMnRuntime> {
|
||||
public:
|
||||
// Everything needed to generate and launch one kernel instance.
// `m, n, k`, the block sizes, `split_factor`, `num_stages` and the two thread counts are
// formatted into the generated source as template arguments; `s`, the A/B tensor maps and the
// raw `float*` output pointer `d` are passed as kernel arguments at launch.
// `launch_args` is not read in this class.
struct Args {
|
||||
int s, m, n, k;
|
||||
int block_m, block_n, block_k;
|
||||
int split_factor;
|
||||
int num_stages;
|
||||
int num_tma_threads, num_math_threads;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
|
||||
CUtensorMap tensor_map_a;
|
||||
CUtensorMap tensor_map_b;
|
||||
float* d;
|
||||
};
|
||||
|
||||
// Returns the source text of a translation unit that takes the address of one
// `sm90_bmn_bnk_mn_gemm_impl<...>` instantiation. The template has 10 `{}` placeholders, filled
// in order by the 10 arguments below; `{{` / `}}` are escaped literal braces for `fmt::format`.
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm90_bmn_bnk_mn_gemm_impl<
|
||||
{}, {}, {},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
|
||||
)",
|
||||
args.m, args.n, args.k,
|
||||
args.block_m, args.block_n, args.block_k,
|
||||
args.split_factor,
|
||||
args.num_stages,
|
||||
args.num_tma_threads, args.num_math_threads);
|
||||
}
|
||||
|
||||
// Launches `kernel` with `config`, passing `s`, the A/B tensor maps and the output pointer as
// kernel arguments; the launch result is checked by `DG_CUDA_UNIFIED_CHECK`.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.s, args.tensor_map_a, args.tensor_map_b, args.d));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Host-side launcher for the SM90 `sm90_bmn_bnk_mn_gemm_impl` kernel.
//
// `a` is described to TMA as a `[s * m, k]` matrix and `b` as `[s * n, k]`, each with `k` as the
// contiguous inner dimension; `d` is handed to the kernel as a raw `float*`, so it must be an
// FP32 tensor (`data_ptr<float>()` rejects other dtypes).
// NOTE(review): presumably this computes `bmk,bnk->mn`, reducing over both the batch `s` and `k`
// with split-K across thread blocks - confirm against the device kernel.
//
// Requirements (asserted): `s, m, n, k` are positive, `k` is a multiple of `block_k` (64),
// `m` and `n` are multiples of 64, `s * max(m, n)` fits into an `int`, and the A/B swizzle
// mode resolves to 128 bytes.
static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a,
                                 const torch::Tensor &b,
                                 const torch::Tensor &d,
                                 const int &s, const int &m, const int &n, const int &k) {
    constexpr int block_m = 128;
    constexpr int block_n = 128;
    constexpr int block_k = 64;
    constexpr int num_tma_threads = 128;
    constexpr int num_math_threads = 256;

    // NOTES: empty problems must be rejected here, otherwise the split factor and grid size
    // computations below divide by zero
    DG_HOST_ASSERT(s > 0 and m > 0 and n > 0 and k > 0);
    DG_HOST_ASSERT(k % block_k == 0);
    DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0);
    // The TMA descriptors take `s * m` and `s * n` as `int` extents
    DG_HOST_ASSERT(static_cast<int64_t>(s) * static_cast<int64_t>(std::max(m, n)) <= std::numeric_limits<int>::max());

    const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast<int>(a.element_size()));
    DG_HOST_ASSERT(swizzle_ab_mode == 128);

    // Get best config: each output tile gets roughly `num_sms / num_mn_blocks` thread blocks,
    // and every thread block handles `split_factor` of the `s * (k / block_k)` reduction blocks
    const int num_sms = device_runtime->get_num_sms();
    const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);
    const int num_sk_blocks = s * (k / block_k);
    const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1));

    // Stage-independent shared memory usages
    constexpr int smem_a_per_stage = block_m * block_k * static_cast<int>(sizeof(cutlass::bfloat16_t));
    constexpr int smem_b_per_stage = block_n * block_k * static_cast<int>(sizeof(cutlass::bfloat16_t));

    // Select best number of stages: start from 4 and shrink until everything fits
    int num_stages = 4, smem_size = 0;
    while (true) {
        smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages
                  + SM90ArchSpec::get_barrier_smem_size(num_stages);
        if (smem_size <= SM90ArchSpec::smem_capacity)
            break;

        -- num_stages;
    }
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    if (get_env("DG_JIT_DEBUG", 0)) {
        printf("S: %d, M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, split-K factor: %d, "
               "stages: %d, shared memory: %d, swizzle AB: %d\n",
               s, m, n, k, block_m, block_n, block_k, split_factor,
               num_stages, smem_size, swizzle_ab_mode);
    }

    const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
    const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);

    // Launch: one thread block per (output tile, reduction split) pair
    const SM90BmkBnkMnRuntime::Args& args = {
        .s = s, .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .split_factor = split_factor,
        .num_stages = num_stages,
        .num_tma_threads = num_tma_threads,
        .num_math_threads = num_math_threads,
        .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_tma_threads + num_math_threads, smem_size),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .d = d.data_ptr<float>()
    };
    const auto& code = SM90BmkBnkMnRuntime::generate(args);
    const auto& runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code);
    SM90BmkBnkMnRuntime::launch(runtime, args);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
214
csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp
Normal file
214
csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp
Normal file
@@ -0,0 +1,214 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime description for the SM90 `sm90_fp8_gemm_1d1d_impl` kernel.
// `generate_impl` produces the source that instantiates the kernel template and
// `launch_impl` launches a built kernel; both are static hooks taking `Args`.
// NOTE(review): presumably these hooks are invoked by the `LaunchRuntime` CRTP base - confirm there.
class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime<SM90FP8Gemm1D1DRuntime> {
|
||||
public:
|
||||
// Everything needed to generate and launch one kernel instance.
// `compiled_dims` is a reference member: the referenced string must outlive this `Args`.
// `launch_args` is not read in this class.
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
|
||||
void *gmem_a_ptr;
|
||||
void *gmem_b_ptr;
|
||||
void *grouped_layout;
|
||||
void *tensor_map_buffer;
|
||||
CUtensorMap tensor_map_a_base;
|
||||
CUtensorMap tensor_map_b_base;
|
||||
CUtensorMap tensor_map_sfa;
|
||||
CUtensorMap tensor_map_sfb;
|
||||
CUtensorMap tensor_map_d;
|
||||
};
|
||||
|
||||
// Returns the source text of a translation unit that takes the address of one
// `sm90_fp8_gemm_1d1d_impl<...>` instantiation. The template has 15 `{}` placeholders, filled in
// order by the 15 arguments below; `{{` / `}}` are escaped literal braces for `fmt::format`.
// `m, n, k` go through `get_compiled_dim` together with `compiled_dims` before being formatted.
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d1d_impl<
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
|
||||
)",
|
||||
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
|
||||
args.num_groups,
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.gemm_config.num_stages,
|
||||
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
|
||||
to_string(args.gemm_config.cd_dtype));
|
||||
}
|
||||
|
||||
// Launches `kernel` with `config`, passing the raw A/B pointers, the grouped layout and tensor
// map buffer pointers, the runtime `m, n, k` and the five tensor maps as kernel arguments;
// the launch result is checked by `DG_CUDA_UNIFIED_CHECK`.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.gmem_a_ptr, args.gmem_b_ptr,
|
||||
args.grouped_layout,
|
||||
args.tensor_map_buffer,
|
||||
args.m, args.n, args.k,
|
||||
args.tensor_map_a_base, args.tensor_map_b_base,
|
||||
args.tensor_map_sfa, args.tensor_map_sfb,
|
||||
args.tensor_map_d));
|
||||
}
|
||||
};
|
||||
|
||||
// Host entry of the SM90 FP8 1D1D GEMM (normal, non-grouped): picks a config, builds the TMA
// descriptors for A/B/SFA/SFB/D, JIT-builds the kernel and launches it.
// Requires `c` to be present, `d` to be FP32, and both A and B to be K-major.
static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                               const torch::Tensor& b, const torch::Tensor& sfb,
                               const std::optional<torch::Tensor>& c,
                               const torch::Tensor& d,
                               const int& m, const int& n, const int& k,
                               const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                               const std::string& compiled_dims) {
    DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    // Heuristically selected kernel configuration
    const auto cfg = get_best_config<SM90ArchSpec>(
        GemmType::Normal, KernelType::Kernel1D1D,
        m, n, k, 1, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());
    const auto& smem = cfg.smem_config;
    const auto& multicast = cfg.multicast_config;

    // The swizzle width must equal `block_k` on both operands (i.e. no TMA splits)
    DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
    DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);

    // Operand descriptors
    const auto tma_a = make_tma_a_desc(
        major_a, a, m, k,
        SM90ArchSpec::get_ab_load_block_m(multicast, cfg.block_m),
        cfg.block_k, k, 1,
        smem.swizzle_a_mode);
    const auto tma_b = make_tma_b_desc(
        major_b, b, n, k,
        SM90ArchSpec::get_ab_load_block_n(multicast, cfg.block_n),
        cfg.block_k, k, 1,
        smem.swizzle_b_mode);

    // Scaling factor descriptors
    const auto tma_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                          cfg.block_m, cfg.block_k, 1, 0);
    const auto tma_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                          cfg.block_n, cfg.block_k, 1, 0);

    // Output descriptor, the last argument (swizzle mode) is fixed to 0 here
    const auto tma_d = make_tma_cd_desc(
        d, m, n,
        SM90ArchSpec::get_cd_store_block_m(cfg.block_m, true),
        SM90ArchSpec::get_cd_store_block_n(cfg.block_n),
        static_cast<int>(d.stride(-2)), 1,
        0);

    // Assemble launch arguments: no raw pointers, grouped layout or tensor map buffer are needed
    const SM90FP8Gemm1D1DRuntime::Args args = {
        .m = m, .n = n, .k = k,
        .num_groups = 1,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                  smem.smem_size,
                                  multicast.num_multicast),
        .gmem_a_ptr = nullptr,
        .gmem_b_ptr = nullptr,
        .grouped_layout = nullptr,
        .tensor_map_buffer = nullptr,
        .tensor_map_a_base = tma_a,
        .tensor_map_b_base = tma_b,
        .tensor_map_sfa = tma_sfa,
        .tensor_map_sfb = tma_sfb,
        .tensor_map_d = tma_d,
    };

    // Generate, build and run
    const auto source = SM90FP8Gemm1D1DRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm90_fp8_gemm_1d1d", source);
    SM90FP8Gemm1D1DRuntime::launch(kernel_runtime, args);
}
|
||||
|
||||
// Host entry of the SM90 FP8 1D1D K-grouped contiguous GEMM.
//
// Arguments:
//   a, sfa, b, sfb:    FP8 operands and their scaling factors
//   c:                 must be present (asserted), it is not otherwise read in this function
//   d:                 FP32 output (asserted)
//   m, n:              output shape of each group
//   ks:                per-group K sizes on the host, must be non-empty and every entry a multiple of 128
//   ks_tensor:         passed to the kernel as `grouped_layout`
//                      NOTE(review): presumably the device copy of `ks`, confirm against the caller
//   tensor_map_buffer: passed to the kernel as `tensor_map_buffer`
//   major_a, major_b:  must both be K-major (asserted)
//   compiled_dims:     forwarded to the code generator
static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                         const torch::Tensor& b, const torch::Tensor& sfb,
                                         const std::optional<torch::Tensor>& c,
                                         const torch::Tensor& d,
                                         const int& m, const int& n,
                                         const std::vector<int>& ks, const torch::Tensor& ks_tensor,
                                         const torch::Tensor& tensor_map_buffer,
                                         const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                         const std::string& compiled_dims) {
    DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    // `std::max_element` returns `end()` for an empty range, dereferencing it is undefined behavior,
    // so reject an empty group list explicitly
    DG_HOST_ASSERT(not ks.empty());

    // Get config using max K for better performance
    const auto& num_groups = static_cast<int>(ks.size());
    const auto& max_k = *std::max_element(ks.begin(), ks.end());
    const auto& config = get_best_config<SM90ArchSpec>(
        GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
        m, n, max_k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // Requires no TMA splits
    DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
    DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);

    // Accumulate the total K, the total number of 128-wide scaling factor columns,
    // and remember the first non-zero K, which serves as the shape of the base A/B descriptors
    int first_k = 0, sum_k = 0, sum_sf_k = 0;
    for (int i = 0; i < num_groups; ++ i) {
        if (first_k == 0 and ks[i] != 0)
            first_k = ks[i];
        sum_k += ks[i], sum_sf_k += ceil_div(ks[i], 128);
        DG_HOST_ASSERT(ks[i] % 128 == 0);
    }
    const auto& tensor_map_a_base = make_tma_a_desc(major_a, a, m, first_k,
                                                    SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
                                                    config.block_k, first_k, 1,
                                                    config.smem_config.swizzle_a_mode);
    const auto& tensor_map_b_base = make_tma_b_desc(major_b, b, n, first_k,
                                                    SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
                                                    config.block_k, first_k, 1,
                                                    config.smem_config.swizzle_b_mode);
    // Scaling factor descriptors span all groups
    const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128,
                                                  config.block_m, config.block_k, 1, 0);
    const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128,
                                                  config.block_n, config.block_k, 1, 0);
    const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
                                                SM90ArchSpec::get_cd_store_block_m(config.block_m, true),
                                                SM90ArchSpec::get_cd_store_block_n(config.block_n),
                                                static_cast<int>(d.stride(-2)), num_groups,
                                                config.smem_config.swizzle_cd_mode);

    // Launch
    const SM90FP8Gemm1D1DRuntime::Args& args = {
        .m = m, .n = n, .k = sum_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  config.smem_config.smem_size,
                                  config.multicast_config.num_multicast),
        .gmem_a_ptr = a.data_ptr(),
        .gmem_b_ptr = b.data_ptr(),
        .grouped_layout = ks_tensor.data_ptr(),
        .tensor_map_buffer = tensor_map_buffer.data_ptr(),
        .tensor_map_a_base = tensor_map_a_base,
        .tensor_map_b_base = tensor_map_b_base,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_d = tensor_map_d,
    };
    const auto& code = SM90FP8Gemm1D1DRuntime::generate(args);
    const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code);

    SM90FP8Gemm1D1DRuntime::launch(runtime, args);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
|
||||
#include "epilogue.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
@@ -17,6 +19,7 @@ public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
const std::optional<std::string>& epilogue_type;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
@@ -43,7 +46,7 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {}
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -55,7 +58,8 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
|
||||
get_default_epilogue_type(args.epilogue_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -74,7 +78,8 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const std::string& compiled_dims,
|
||||
const std::optional<std::string>& epilogue_type = std::nullopt) {
|
||||
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
|
||||
@@ -98,7 +103,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
@@ -111,6 +116,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = epilogue_type,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -170,6 +176,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -230,6 +237,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
|
||||
151
csrc/jit_kernels/impls/smxx_cublaslt.hpp
Normal file
151
csrc/jit_kernels/impls/smxx_cublaslt.hpp
Normal file
@@ -0,0 +1,151 @@
|
||||
#pragma once
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDADataType.h>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
static auto get_cublaslt_layout(const cudaDataType& type, const int& rows, const int& cols, const int& ld,
|
||||
const std::optional<int>& batch_count = std::nullopt,
|
||||
const std::optional<int>& batch_offset = std::nullopt) {
|
||||
cublasLtMatrixLayout_t layout;
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&layout, type, rows, cols, ld));
|
||||
if (batch_count.has_value()) {
|
||||
DG_HOST_ASSERT(batch_offset.has_value());
|
||||
|
||||
const int64_t batch_offset_int64 = batch_offset.value();
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count.value(), sizeof(batch_count.value())));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_offset_int64, sizeof(batch_offset_int64)));
|
||||
}
|
||||
return layout;
|
||||
}
|
||||
|
||||
static void call_cublaslt_api(const cublasOperation_t& trans_a,
|
||||
const cublasOperation_t& trans_b,
|
||||
const cublasLtMatrixLayout_t& layout_a,
|
||||
const cublasLtMatrixLayout_t& layout_b,
|
||||
const cublasLtMatrixLayout_t& layout_d,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const bool& accumulate) {
|
||||
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
cudaDataType_t scale_type = CUDA_R_32F;
|
||||
const int& math_sms = device_runtime->get_num_sms();
|
||||
bool fp8_fast_accumulate = false;
|
||||
|
||||
// Operation description
|
||||
cublasLtMatmulDesc_t desc;
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescCreate(&desc, compute_type, scale_type));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms)));
|
||||
if (a.scalar_type() == torch::kFloat8_e4m3fn)
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate)));
|
||||
|
||||
// Get cuBLASLt handle, workspace, and stream
|
||||
const auto& handle = device_runtime->get_cublaslt_handle();
|
||||
const auto& workspace = device_runtime->get_cublaslt_workspace();
|
||||
const auto& workspace_bytes = workspace.nbytes();
|
||||
const auto& stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// Algorithm selection
|
||||
cublasLtMatmulPreference_t pref;
|
||||
cublasLtMatmulHeuristicResult_t heuristic;
|
||||
int num_heuristic_results = 0;
|
||||
uint32_t reduction_scheme_mask = CUBLASLT_REDUCTION_SCHEME_NONE | CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE;
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceCreate(&pref));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&workspace_bytes, sizeof(workspace_bytes)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
|
||||
&reduction_scheme_mask, sizeof(reduction_scheme_mask)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, layout_a, layout_b, layout_d, layout_d,
|
||||
pref, 1, &heuristic, &num_heuristic_results));
|
||||
DG_HOST_ASSERT(num_heuristic_results == 1 and "Unable to find any algorithm for the GEMM");
|
||||
|
||||
// Call: D = alpha * (A @ B) + beta * C
|
||||
const float& alpha = 1.0, beta = accumulate ? 1.0 : 0.0;
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmul(handle, // Light handle
|
||||
desc, // Operation description
|
||||
&alpha, // Alpha
|
||||
b.data_ptr(), layout_a, // A
|
||||
a.data_ptr(), layout_b, // B
|
||||
&beta, // Beta
|
||||
d.data_ptr(), layout_d, // C
|
||||
d.data_ptr(), layout_d, // D
|
||||
&heuristic.algo, // Algorithm
|
||||
workspace.data_ptr(), workspace_bytes, // Workspace
|
||||
stream)); // Stream
|
||||
|
||||
// Free memory
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceDestroy(pref));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_a));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_b));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_d));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescDestroy(desc));
|
||||
}
|
||||
|
||||
// cuBLASLt GEMM wrapper: computes `out = lhs @ rhs` (plus `acc` if given) for the given majors.
// `rhs` is described as cuBLASLt's A operand and `lhs` as its B operand, with an `n x m` output layout.
static void cublaslt_gemm(const torch::Tensor& lhs, const torch::Tensor& rhs,
                          const std::optional<torch::Tensor>& acc,
                          const torch::Tensor& out,
                          const int& m, const int& n, const int& k,
                          const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major) {
    const bool is_lhs_k_major = a_major == cute::UMMA::Major::K;
    const bool is_rhs_k_major = b_major == cute::UMMA::Major::K;

    // Transpose flags follow the operand swap: `trans_a` depends on `rhs`, `trans_b` on `lhs`
    const auto trans_a = is_rhs_k_major ? CUBLAS_OP_T : CUBLAS_OP_N;
    const auto trans_b = is_lhs_k_major ? CUBLAS_OP_N : CUBLAS_OP_T;

    // With an accumulator, the result is accumulated into `out`:
    // copy `acc` into `out` first unless they already share storage
    // TODO: remove this
    if (acc.has_value()) {
        const bool is_inplace = acc->data_ptr() == out.data_ptr();
        if (not is_inplace) {
            out.copy_(acc.value());
        } else {
            DG_HOST_ASSERT(acc->sizes() == out.sizes() and acc->strides() == out.strides());
        }
    }

    // Data types of the (swapped) operands and the output
    const auto dtype_a = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type());
    const auto dtype_b = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type());
    const auto dtype_d = at::cuda::ScalarTypeToCudaDataType(out.scalar_type());

    // Matrix layouts, the leading dimension is the stride of the non-contiguous dimension
    const auto layout_a = is_rhs_k_major ? get_cublaslt_layout(dtype_a, k, n, rhs.stride(0))
                                         : get_cublaslt_layout(dtype_a, n, k, rhs.stride(1));
    const auto layout_b = is_lhs_k_major ? get_cublaslt_layout(dtype_b, k, m, lhs.stride(0))
                                         : get_cublaslt_layout(dtype_b, m, k, lhs.stride(1));
    const auto layout_d = get_cublaslt_layout(dtype_d, n, m, out.stride(0));

    call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, acc.has_value());
}
|
||||
|
||||
|
||||
// Batched BF16 cuBLASLt matmul for the `bhr, hdr -> bhd` pattern, batched over `h`.
// Uses `m = d`, `n = b`, `k = r`, with the first cuBLASLt operand (`rhs`) transposed.
static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
                                 const int& b, const int& h, const int& r, const int& d) {
    const int m = d;
    const int n = b;
    const int k = r;
    constexpr cublasOperation_t trans_a = CUBLAS_OP_T;
    constexpr cublasOperation_t trans_b = CUBLAS_OP_N;

    // Matrix layouts: `h` is the batch count, the last argument is the stride between batches
    const auto layout_a = get_cublaslt_layout(CUDA_R_16BF, k, m, rhs.stride(1), h, rhs.stride(0));
    const auto layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1));
    const auto layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1));

    // No accumulation
    call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false);
}
|
||||
|
||||
|
||||
// Batched BF16 cuBLASLt matmul for the `bhd, hdr -> bhr` pattern, batched over `h`.
// Uses `m = r`, `n = b`, `k = d`, with neither cuBLASLt operand transposed.
static void cublaslt_bhd_hdr_bhr(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
                                 const int& b, const int& h, const int& r, const int& d) {
    const int m = r;
    const int n = b;
    const int k = d;
    constexpr cublasOperation_t trans_a = CUBLAS_OP_N;
    constexpr cublasOperation_t trans_b = CUBLAS_OP_N;

    // Matrix layouts: `h` is the batch count, the last argument is the stride between batches
    const auto layout_a = get_cublaslt_layout(CUDA_R_16BF, m, k, rhs.stride(1), h, rhs.stride(0));
    const auto layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1));
    const auto layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1));

    // No accumulation
    call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -1,6 +1,8 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "apis/attention.hpp"
|
||||
#include "apis/einsum.hpp"
|
||||
#include "apis/gemm.hpp"
|
||||
#include "apis/layout.hpp"
|
||||
#include "apis/runtime.hpp"
|
||||
@@ -13,6 +15,8 @@
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "DeepGEMM C++ library";
|
||||
|
||||
deep_gemm::attention::register_apis(m);
|
||||
deep_gemm::einsum::register_apis(m);
|
||||
deep_gemm::gemm::register_apis(m);
|
||||
deep_gemm::layout::register_apis(m);
|
||||
deep_gemm::runtime::register_apis(m);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <exception>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
@@ -72,4 +73,16 @@ do { \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
// Evaluate a cuBLASLt call exactly once; unless it returns `CUBLAS_STATUS_SUCCESS`, throw a
// `DGException` tagged "cuBLASLt" whose message holds the numeric status and its status string.
#ifndef DG_CUBLASLT_CHECK
#define DG_CUBLASLT_CHECK(cmd) \
do { \
    const auto& e = (cmd); \
    if (e == CUBLAS_STATUS_SUCCESS) \
        break; \
    std::ostringstream ss; \
    ss << static_cast<int>(e) << " (" << cublasGetStatusString(e) << ")"; \
    throw DGException("cuBLASLt", __FILE__, __LINE__, ss.str()); \
} while (0)
#endif
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <filesystem>
|
||||
#include <functional>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "exception.hpp"
|
||||
#include "format.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
@@ -65,7 +69,10 @@ static std::filesystem::path make_dirs(const std::filesystem::path& path) {
|
||||
// OK if existed
|
||||
std::error_code capture;
|
||||
const bool& created = std::filesystem::create_directories(path, capture);
|
||||
DG_HOST_ASSERT(created or capture.value() == 0);
|
||||
if (not (created or capture.value() == 0)) {
|
||||
DG_HOST_UNREACHABLE(fmt::format("Failed to make directory: {}, created: {}, value: {}",
|
||||
path.c_str(), created, capture.value()));
|
||||
}
|
||||
if (created and get_env<int>("DG_JIT_DEBUG"))
|
||||
printf("Create directory: %s\n", path.c_str());
|
||||
return path;
|
||||
|
||||
@@ -25,15 +25,22 @@ from deep_gemm_cpp import (
|
||||
# FP8 GEMMs
|
||||
fp8_gemm_nt, fp8_gemm_nn,
|
||||
fp8_gemm_tn, fp8_gemm_tt,
|
||||
fp8_gemm_nt_skip_head_mid,
|
||||
m_grouped_fp8_gemm_nt_contiguous,
|
||||
m_grouped_fp8_gemm_nn_contiguous,
|
||||
m_grouped_fp8_gemm_nt_masked,
|
||||
k_grouped_fp8_gemm_nt_contiguous,
|
||||
k_grouped_fp8_gemm_tn_contiguous,
|
||||
# BF16 GEMMs
|
||||
bf16_gemm_nt, bf16_gemm_nn,
|
||||
bf16_gemm_tn, bf16_gemm_tt,
|
||||
m_grouped_bf16_gemm_nt_contiguous,
|
||||
m_grouped_bf16_gemm_nt_masked,
|
||||
# cuBLASLt GEMMs
|
||||
cublaslt_gemm_nt, cublaslt_gemm_nn,
|
||||
cublaslt_gemm_tn, cublaslt_gemm_tt,
|
||||
# Einsum kernels
|
||||
einsum,
|
||||
# Layout kernels
|
||||
transform_sf_into_required_layout
|
||||
)
|
||||
|
||||
27
deep_gemm/include/deep_gemm/common/epilogue_utils.cuh
Normal file
27
deep_gemm/include/deep_gemm/common/epilogue_utils.cuh
Normal file
@@ -0,0 +1,27 @@
|
||||
#pragma once
|
||||
|
||||
#include <deep_gemm/common/types.hpp>
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
struct EpilogueIdentity {
|
||||
template <uint32_t STORE_BLOCK_N>
|
||||
__device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
|
||||
return n_idx;
|
||||
}
|
||||
};
|
||||
|
||||
template <uint32_t kLeft, uint32_t kMid, uint32_t kRight>
|
||||
struct EpilogueHeadSplits: EpilogueIdentity {
|
||||
template <uint32_t STORE_BLOCK_N>
|
||||
__device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
|
||||
DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0
|
||||
and kRight % STORE_BLOCK_N == 0, "Invalid head splits config");
|
||||
return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid;
|
||||
}
|
||||
};
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
} // namespace deep_gemm
|
||||
44
deep_gemm/include/deep_gemm/common/reduction.cuh
Normal file
44
deep_gemm/include/deep_gemm/common/reduction.cuh
Normal file
@@ -0,0 +1,44 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda/std/cstdint>
|
||||
#include <cuda/std/utility>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
|
||||
// Operation functors
|
||||
// Binary reduction functors, each combines two values of type `T`

// Addition
template <typename T>
struct ReduceSum {
    __device__ T operator()(T a, T b) const {
        return a + b;
    }
};

// Maximum, returns `b` unless `a > b`
template <typename T>
struct ReduceMax {
    __device__ T operator()(T a, T b) const {
        if (a > b)
            return a;
        return b;
    }
};

// Minimum, returns `b` unless `a < b`
template <typename T>
struct ReduceMin {
    __device__ T operator()(T a, T b) const {
        if (a < b)
            return a;
        return b;
    }
};

// Bitwise AND
template <typename T>
struct ReduceAnd {
    __device__ T operator()(T a, T b) const {
        return a & b;
    }
};

// Bitwise OR
template <typename T>
struct ReduceOr {
    __device__ T operator()(T a, T b) const {
        return a | b;
    }
};
|
||||
|
||||
// Unified reduction function
|
||||
template <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
|
||||
__forceinline__ __device__ T warp_reduce(T value, Op op) {
|
||||
DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
|
||||
kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
|
||||
"Invalid number of lanes");
|
||||
constexpr uint32_t mask = 0xffffffff;
|
||||
if constexpr (kIntergroupReduce) {
|
||||
if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1));
|
||||
if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2));
|
||||
if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4));
|
||||
if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8));
|
||||
if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
|
||||
} else {
|
||||
if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
|
||||
if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8));
|
||||
if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4));
|
||||
if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2));
|
||||
if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1));
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
// Convenience aliases
|
||||
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
|
||||
__forceinline__ __device__ T warp_reduce_sum(T value) {
|
||||
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
|
||||
}
|
||||
@@ -22,7 +22,6 @@ static constexpr uint32_t get_num_1d_blocks_per_group() {
|
||||
if (usage < min_usage)
|
||||
min_usage = usage, num_best_blocks = candidate;
|
||||
}
|
||||
|
||||
return num_best_blocks;
|
||||
}
|
||||
|
||||
@@ -33,6 +32,7 @@ template <GemmType kGemmType,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
uint32_t kNumSMs,
|
||||
uint32_t SF_K_ALIGNMENT = 512u, // for k-grouped GEMM only: 128 (SM90 float SF) or 512 (SM100 UE8M0 SF)
|
||||
uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<kGemmType, BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
|
||||
struct Scheduler {
|
||||
int current_iter = -1;
|
||||
@@ -48,30 +48,40 @@ struct Scheduler {
|
||||
|
||||
// For grouped GEMM
|
||||
int* grouped_layout;
|
||||
uint32_t current_group_idx;
|
||||
uint32_t current_group_idx = 0;
|
||||
// Only used for masked layout
|
||||
uint32_t current_m_cumsum;
|
||||
uint32_t current_m_cumsum = 0;
|
||||
// Only used for k-grouped layout
|
||||
uint32_t current_shape_k, current_num_valid_groups, current_k_cumsum, current_sf_k_cumsum;
|
||||
uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0;
|
||||
uint32_t next_group_idx, next_shape_k;
|
||||
|
||||
// Only used for k-grouped gemm
|
||||
// Starting from `group_idx`, advance to the first group whose K (read from `grouped_layout`) is positive.
// On return, either `group_idx` points to that group and `shape_k` holds its K,
// or `group_idx == kNumGroups` and `shape_k` holds the last value read (unchanged if nothing was read).
__device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const {
    while (group_idx < kNumGroups) {
        shape_k = __ldg(grouped_layout + group_idx);
        if (shape_k > 0)
            return;
        ++ group_idx;
    }
}
|
||||
|
||||
// ReSharper disable once CppPossiblyUninitializedMember
|
||||
__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n,
|
||||
__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k,
|
||||
int* grouped_layout = nullptr) {
|
||||
num_m_blocks = ceil_div(shape_m, BLOCK_M);
|
||||
num_n_blocks = ceil_div(shape_n, BLOCK_N);
|
||||
current_shape_k = shape_k;
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
num_blocks = num_m_blocks * num_n_blocks;
|
||||
} else if (kGemmType == GemmType::MGroupedContiguous) {
|
||||
num_blocks = num_m_blocks * num_n_blocks;
|
||||
this->grouped_layout = grouped_layout;
|
||||
} else if (kGemmType == GemmType::MGroupedMasked) {
|
||||
current_group_idx = current_m_cumsum = 0;
|
||||
this->grouped_layout = grouped_layout;
|
||||
} else if (kGemmType == GemmType::KGroupedContiguous) {
|
||||
current_group_idx = current_num_valid_groups = 0;
|
||||
current_k_cumsum = current_sf_k_cumsum = 0;
|
||||
current_shape_k = __ldg(grouped_layout + current_group_idx);
|
||||
this->grouped_layout = grouped_layout;
|
||||
get_next_k_group(current_group_idx, current_shape_k);
|
||||
next_group_idx = current_group_idx + 1;
|
||||
get_next_k_group(next_group_idx, next_shape_k);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,17 +175,17 @@ struct Scheduler {
|
||||
return false;
|
||||
|
||||
// Within current group
|
||||
if (current_shape_k > 0 and next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks)
|
||||
if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks)
|
||||
break;
|
||||
|
||||
// Move to check the next group
|
||||
if (current_shape_k > 0) {
|
||||
current_k_cumsum += current_shape_k;
|
||||
current_sf_k_cumsum += ceil_div(current_shape_k, 512u);
|
||||
current_num_valid_groups ++;
|
||||
}
|
||||
if ((++ current_group_idx) != kNumGroups)
|
||||
current_shape_k = __ldg(grouped_layout + current_group_idx);
|
||||
current_k_cumsum += current_shape_k;
|
||||
current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT);
|
||||
current_num_valid_groups ++;
|
||||
|
||||
current_group_idx = next_group_idx ++;
|
||||
current_shape_k = next_shape_k;
|
||||
get_next_k_group(next_group_idx, next_shape_k);
|
||||
}
|
||||
|
||||
get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx);
|
||||
@@ -197,7 +207,7 @@ struct Scheduler {
|
||||
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
|
||||
if (num_blocks_in_group == 1)
|
||||
return false;
|
||||
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked) {
|
||||
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::KGroupedContiguous) {
|
||||
return true;
|
||||
} else {
|
||||
DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type");
|
||||
|
||||
@@ -79,12 +79,24 @@ void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_p
|
||||
desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);
|
||||
}
|
||||
|
||||
// Atom base of a UMMA layout type: 32 for `SWIZZLE_128B_BASE32B`, 16 for every other layout
__device__ __forceinline__
static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) {
    if (layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B)
        return 32;
    return 16;
}
|
||||
|
||||
// ReSharper disable once CppNotAllPathsReturnValue
|
||||
template <uint32_t kSwizzleMode>
|
||||
template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, bool kUseBase32, typename dtype_t>
|
||||
constexpr static cute::UMMA::LayoutType to_umma_layout_type() {
|
||||
DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
|
||||
kSwizzleMode == 32 or kSwizzleMode == 64 or
|
||||
kSwizzleMode == 128, "Invalid swizzling mode");
|
||||
// A special case
|
||||
if constexpr ((cute::is_same_v<dtype_t, float> and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) {
|
||||
DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base");
|
||||
return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B;
|
||||
}
|
||||
|
||||
// Normal cases
|
||||
if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE;
|
||||
if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE;
|
||||
if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B;
|
||||
@@ -104,10 +116,12 @@ uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, cons
|
||||
return base + (((offset + k_idx * get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
|
||||
}
|
||||
|
||||
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
|
||||
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, bool kUseBase32 = false, typename dtype_t>
|
||||
__device__ __forceinline__
|
||||
cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
|
||||
const uint32_t stride_k = get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
|
||||
const auto& layout_type = to_umma_layout_type<kMajorMode, kSwizzleMode, kUseBase32, dtype_t>();
|
||||
const auto& num_non_contiguous = 128 / get_atom_base(layout_type);
|
||||
if constexpr (kMajorMode == cute::UMMA::Major::K) {
|
||||
// NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128
|
||||
DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
|
||||
@@ -115,9 +129,9 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id
|
||||
// Atom size: 8 x `kSwizzleMode` (in bytes, on K)
|
||||
// {SBO, LBO} means the byte stride between atoms on {MN, K}
|
||||
// NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
|
||||
const uint32_t stride_byte_offset = 8 * BLOCK_K * sizeof(dtype_t);
|
||||
const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
|
||||
const uint32_t leading_byte_offset = 0;
|
||||
return make_smem_desc(to_umma_layout_type<kSwizzleMode>(),
|
||||
return make_smem_desc(layout_type,
|
||||
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
|
||||
stride_byte_offset, leading_byte_offset);
|
||||
} else {
|
||||
@@ -132,11 +146,11 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id
|
||||
// NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving
|
||||
// {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
|
||||
// {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
|
||||
uint32_t stride_byte_offset = 8 * BLOCK_MN_ATOM * sizeof(dtype_t);
|
||||
uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
|
||||
uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
|
||||
if constexpr (kSwizzleMode == 16)
|
||||
swap(stride_byte_offset, leading_byte_offset);
|
||||
return make_smem_desc(to_umma_layout_type<kSwizzleMode>(),
|
||||
return make_smem_desc(layout_type,
|
||||
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
|
||||
stride_byte_offset, leading_byte_offset);
|
||||
}
|
||||
@@ -166,4 +180,81 @@ __device__ __forceinline__ void tcgen05_after_thread_sync() {
|
||||
asm volatile("tcgen05.fence::after_thread_sync;");
|
||||
}
|
||||
|
||||
// UMMA versions with relaxed assertions
|
||||
// Thin wrapper that emits a single `tcgen05.mma.cta_group::1.kind::f16` instruction.
// Unlike a typed MMA atom, this struct takes raw descriptors and has no template parameters.
struct SM100_MMA_F16BF16_SS {
|
||||
// Parameters:
//   desc_a, desc_b: 64-bit descriptors passed as the A and B operands of the instruction
//   tmem_c: 32-bit value used as the bracketed address operand (`[%0]`)
//   scale_c: the instruction's predicate operand is set iff this is non-zero
//   desc: only its upper 32 bits are forwarded (as the 4th instruction operand)
// NOTE(review): per the PTX ISA the trailing predicate is presumably `enable-input-d`
// (accumulate onto the existing contents when true) — confirm against the PTX documentation.
__device__ static void
|
||||
fma(uint64_t const& desc_a,
|
||||
uint64_t const& desc_b,
|
||||
uint32_t const& tmem_c,
|
||||
uint32_t const& scale_c,
|
||||
uint64_t const& desc) {
|
||||
// Operand map: %0 = tmem_c, %1 = desc_a, %2 = desc_b, %3 = high half of desc, %4 = scale_c
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
"setp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
|
||||
"}\n"
|
||||
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
|
||||
}
|
||||
};
|
||||
|
||||
// Thin wrapper that emits a single `tcgen05.mma.cta_group::2.kind::f16` instruction.
// Identical to the `cta_group::1` wrapper except for the CTA group qualifier in the PTX string.
struct SM100_MMA_F16BF16_2x1SM_SS {
|
||||
// Parameters:
//   desc_a, desc_b: 64-bit descriptors passed as the A and B operands of the instruction
//   tmem_c: 32-bit value used as the bracketed address operand (`[%0]`)
//   scale_c: the instruction's predicate operand is set iff this is non-zero
//   desc: only its upper 32 bits are forwarded (as the 4th instruction operand)
// NOTE(review): per the PTX ISA the trailing predicate is presumably `enable-input-d`
// (accumulate onto the existing contents when true) — confirm against the PTX documentation.
__device__ static void
|
||||
fma(uint64_t const& desc_a,
|
||||
uint64_t const& desc_b,
|
||||
uint32_t const& tmem_c,
|
||||
uint32_t const& scale_c,
|
||||
uint64_t const& desc) {
|
||||
// Operand map: %0 = tmem_c, %1 = desc_a, %2 = desc_b, %3 = high half of desc, %4 = scale_c
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
"setp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t"
|
||||
"}\n"
|
||||
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
|
||||
}
|
||||
};
|
||||
|
||||
// Thin wrapper that emits a single `tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale` instruction.
struct SM100_MMA_MXF8F6F4_SS {
|
||||
// Parameters:
//   desc_a, desc_b: 64-bit descriptors passed as the A and B operands of the instruction
//   tmem_c: 32-bit value used as the first bracketed address operand (`[%0]`)
//   scale_c: the instruction's predicate operand is set iff this is non-zero
//   desc: only its upper 32 bits are forwarded (as the 4th instruction operand)
//   tmem_sfa, tmem_sfb: 32-bit values used as the two extra bracketed operands (`[%5]`, `[%6]`);
//                       presumably the addresses of the scaling factors of A and B — TODO confirm
__device__ static void
|
||||
fma(uint64_t const& desc_a,
|
||||
uint64_t const& desc_b,
|
||||
uint32_t const& tmem_c,
|
||||
uint32_t const& scale_c,
|
||||
uint64_t const& desc,
|
||||
uint32_t const& tmem_sfa,
|
||||
uint32_t const& tmem_sfb) {
|
||||
// Operand map: %0 = tmem_c, %1 = desc_a, %2 = desc_b, %3 = high half of desc, %4 = scale_c,
// %5 = tmem_sfa, %6 = tmem_sfb
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
"setp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
|
||||
"}\n"
|
||||
:
|
||||
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
|
||||
"r"(tmem_sfa), "r"(tmem_sfb));
|
||||
}
|
||||
};
|
||||
|
||||
// Thin wrapper that emits a single `tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale` instruction.
// Identical to the `cta_group::1` wrapper except for the CTA group qualifier in the PTX string.
struct SM100_MMA_MXF8F6F4_2x1SM_SS {
|
||||
// Parameters:
//   desc_a, desc_b: 64-bit descriptors passed as the A and B operands of the instruction
//   tmem_c: 32-bit value used as the first bracketed address operand (`[%0]`)
//   scale_c: the instruction's predicate operand is set iff this is non-zero
//   desc: only its upper 32 bits are forwarded (as the 4th instruction operand)
//   tmem_sfa, tmem_sfb: 32-bit values used as the two extra bracketed operands (`[%5]`, `[%6]`);
//                       presumably the addresses of the scaling factors of A and B — TODO confirm
__device__ static void
|
||||
fma(uint64_t const& desc_a,
|
||||
uint64_t const& desc_b,
|
||||
uint32_t const& tmem_c,
|
||||
uint32_t const& scale_c,
|
||||
uint64_t const& desc,
|
||||
uint32_t const& tmem_sfa,
|
||||
uint32_t const& tmem_sfb) {
|
||||
// Operand map: %0 = tmem_c, %1 = desc_a, %2 = desc_b, %3 = high half of desc, %4 = scale_c,
// %5 = tmem_sfa, %6 = tmem_sfb
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
"setp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
|
||||
"}\n"
|
||||
:
|
||||
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
|
||||
"r"(tmem_sfa), "r"(tmem_sfb));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace `deep_gemm::sm100`
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/mma_sm90_gmma.hpp>
|
||||
#include <cute/arch/mma_sm90_gmma_ext.hpp>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
|
||||
namespace deep_gemm::sm90 {
|
||||
|
||||
template <int N_, typename MMA>
|
||||
@@ -29,6 +33,7 @@ struct FP8MMASelector {
|
||||
|
||||
static constexpr auto select_mma() {
|
||||
using namespace cute::SM90::GMMA;
|
||||
if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN();
|
||||
@@ -93,6 +98,7 @@ struct BF16MMASelector {
|
||||
|
||||
static constexpr auto select_mma() {
|
||||
using namespace cute::SM90::GMMA;
|
||||
if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
@@ -144,6 +150,24 @@ struct SM90_U32x2_STSM_N {
|
||||
}
|
||||
};
|
||||
|
||||
// Wrapper around `ldmatrix.sync.aligned.x2.m8n8.shared.b16`, producing two 32-bit registers.
struct SM90_U32x2_LDSM_N {
|
||||
// Parameters:
//   dst_0, dst_1: outputs, overwritten with the two 32-bit results of the instruction
//   smem_src: source address operand of the load
// NOTE(review): `smem_src` is handed to the `.shared` instruction as a 64-bit register ("l")
// without an explicit generic-to-shared conversion — confirm this is intended.
__device__ __forceinline__ static void
|
||||
copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) {
|
||||
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
|
||||
: "=r"(dst_0), "=r"(dst_1)
|
||||
: "l"(smem_src));
|
||||
}
|
||||
};
|
||||
|
||||
// Wrapper around `ldmatrix.sync.aligned.x4.m8n8.shared.b16`, producing four 32-bit registers.
struct SM90_U32x4_LDSM_N {
|
||||
// Parameters:
//   dst_0 .. dst_3: outputs, overwritten with the four 32-bit results of the instruction
//   smem_src: source address operand of the load
// NOTE(review): `smem_src` is handed to the `.shared` instruction as a 64-bit register ("l")
// without an explicit generic-to-shared conversion — confirm this is intended.
__device__ __forceinline__ static void
|
||||
copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) {
|
||||
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3)
|
||||
: "l"(smem_src));
|
||||
}
|
||||
};
|
||||
|
||||
// Emits `wgmma.fence.sync.aligned`; the "memory" clobber also stops the compiler from
// reordering memory accesses across this statement.
__forceinline__ __device__ void warpgroup_arrive() {
|
||||
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
|
||||
}
|
||||
@@ -223,4 +247,37 @@ tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
}
|
||||
}
|
||||
|
||||
// Issues one 3D TMA load through `cute::SM90_TMA_LOAD_3D::copy`, always using the `EVICT_NORMAL` cache hint.
// The descriptor, barrier, destination pointer and the three coordinates are forwarded unchanged.
__device__ __forceinline__ void
|
||||
tma_3d_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& crd_2) {
|
||||
// The hint is fixed at compile time; callers cannot override it
constexpr uint64_t kEvictNormalHint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
|
||||
cute::SM90_TMA_LOAD_3D::copy(desc_ptr, barrier_ptr, kEvictNormalHint, smem_ptr, crd_0, crd_1, crd_2);
|
||||
}
|
||||
|
||||
// Tensormap related
|
||||
// Emits `fence.proxy.tensormap::generic.release.cta` (release side of the tensor-map proxy fence, CTA scope).
// NOTE(review): unlike the acquire counterpart, this asm statement declares no "memory" clobber — confirm intended.
__device__ __forceinline__ void tensor_map_release_cta() {
|
||||
asm volatile ("fence.proxy.tensormap::generic.release.cta;");
|
||||
}
|
||||
|
||||
// Emits `fence.proxy.tensormap::generic.acquire.cta` on the descriptor at `gmem_desc_ptr`.
// The literal `128` is the size operand of the fence; presumably the byte size of a tensor map — TODO confirm.
__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) {
|
||||
// The pointer is passed to PTX as a plain 64-bit integer address
auto gmem_int_desc = reinterpret_cast<uint64_t>(gmem_desc_ptr);
|
||||
asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory");
|
||||
}
|
||||
|
||||
// Overwrites the global address field of the tensor map at `smem_desc` with `new_addr`,
// via `tensormap.replace.tile.global_address.shared::cta.b1024.b64`.
__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) {
|
||||
// Convert the generic pointer into a shared-space address, narrowed to 32 bits for the "r" operand
auto smem_int_desc = static_cast<uint32_t>(__cvta_generic_to_shared(smem_desc));
|
||||
const auto new_int64_addr = reinterpret_cast<uint64_t>(new_addr);
|
||||
asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr));
|
||||
}
|
||||
|
||||
// Overwrites, for field index 0 (the literal `0` operand), the global dimension and the global stride
// of the tensor map at `smem_desc`:
//   - `new_dim` is written with `tensormap.replace.tile.global_dim` (32-bit)
//   - `new_stride` is written with `tensormap.replace.tile.global_stride` (64-bit)
// The stride replacement is only compiled with CUDA 12.5 or newer; older toolkits hit a static assertion.
__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) {
|
||||
// Shared-space address of the descriptor, kept at its native width and passed as a 64-bit operand
auto smem_int_desc = __cvta_generic_to_shared(smem_desc);
|
||||
asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim));
|
||||
#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 5)))
|
||||
asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride));
|
||||
#else
|
||||
// The terminating `;` is required so that this branch reports the assertion message, not a parse error
DG_STATIC_ASSERT(false, "Invalid CUDA version");
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace `deep_gemm::sm90`
|
||||
|
||||
@@ -104,6 +104,12 @@ __device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Loads a `float2` from `ptr` with a single vectorized `ld.shared.v2.f32` and returns it.
__device__ __forceinline__ float2 ld_shared(const float2* ptr) {
|
||||
float2 ret;
|
||||
// Outputs %0/%1 map to ret.x/ret.y; the pointer is passed as a 64-bit address operand
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float4 ld_shared(const float4* ptr) {
|
||||
float4 ret;
|
||||
asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(ptr));
|
||||
@@ -126,10 +132,18 @@ __device__ __forceinline__ void st_shared(const float* ptr, float val) {
|
||||
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
|
||||
}
|
||||
|
||||
// Stores `val.x` and `val.y` to `ptr` with a single vectorized `st.shared.v2.f32`.
// NOTE: `ptr` is `const`-qualified in the signature although the pointee is written through the asm.
__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
|
||||
asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y));
|
||||
}
|
||||
|
||||
// Stores one 32-bit value to `ptr` with `st.shared.u32`.
// NOTE: `ptr` is `const`-qualified in the signature although the pointee is written through the asm.
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
|
||||
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
|
||||
}
|
||||
|
||||
// Stores two 32-bit values (`x`, then `y`) to `ptr` with a single vectorized `st.shared.v2.u32`.
// NOTE: `ptr` is `const`-qualified in the signature although the pointee is written through the asm.
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) {
|
||||
asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(ptr), "r"(x), "r"(y));
|
||||
}
|
||||
|
||||
// Stores four 32-bit values (`x`, `y`, `z`, `w`, in that order) to `ptr` with a single vectorized `st.shared.v4.u32`.
// NOTE: `ptr` is `const`-qualified in the signature although the pointee is written through the asm.
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
|
||||
asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(ptr), "r"(x), "r"(y), "r"(z), "r"(w));
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumStages,
|
||||
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
uint32_t kNumSMs,
|
||||
@@ -84,8 +84,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
||||
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == 0) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
@@ -93,35 +92,31 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
||||
cute::prefetch_tma_descriptor(&tensor_map_c);
|
||||
}
|
||||
|
||||
// Data on shared memory (layout as ordered below)
|
||||
cd_dtype_t* smem_cd[kNumTMAStoreStages];
|
||||
cutlass::bfloat16_t* smem_a[kNumStages];
|
||||
cutlass::bfloat16_t* smem_b[kNumStages];
|
||||
|
||||
// Fill D/A/B pointers
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i)
|
||||
smem_cd[i] = reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
}
|
||||
// D/A/B shared memory
|
||||
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
||||
});
|
||||
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
});
|
||||
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
});
|
||||
|
||||
// Fill barriers
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
|
||||
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
||||
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
||||
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
||||
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
||||
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
|
||||
auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2;
|
||||
|
||||
// Fill the tensor memory pointer
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1);
|
||||
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
||||
|
||||
// Initialize barriers
|
||||
if (threadIdx.x == 0) {
|
||||
if (warp_idx == 1 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
// Arrive only at the leader CTA
|
||||
@@ -136,11 +131,12 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
||||
// Arrive only at the leader CTA
|
||||
tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads);
|
||||
}
|
||||
if constexpr (kTensorCoreUtilControl < 100)
|
||||
tensor_core_full_barrier->init(1);
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::fence_barrier_init();
|
||||
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
|
||||
} else if (warp_idx == 2) {
|
||||
// Allocate tensor memory
|
||||
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
@@ -148,100 +144,69 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
uint32_t phase = 0;
|
||||
auto launch_k_iterations = [&](const auto& func) {
|
||||
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
|
||||
const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K);
|
||||
const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages;
|
||||
// Pipeline and TMA phases
|
||||
uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0;
|
||||
auto advance_pipeline = [&](uint32_t& k_block_idx) {
|
||||
++ k_block_idx;
|
||||
|
||||
// TODO: refactor here
|
||||
if (num_last_stages == 0) {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1)
|
||||
func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages);
|
||||
} else {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1)
|
||||
func(k_iter, DivisibleK{}, false, num_last_stages);
|
||||
func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1;
|
||||
}
|
||||
};
|
||||
|
||||
auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) {
|
||||
DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2,
|
||||
"Too many epilogue stages, please modify the Python heuristic as well");
|
||||
accum_stage_idx == 0 ? func(0) : func(1);
|
||||
// Flip phases only if reach the next first stage
|
||||
stage_idx = (stage_idx + 1) % kNumStages;
|
||||
phase ^= stage_idx == 0;
|
||||
};
|
||||
|
||||
// Dispatch warps into different roles
|
||||
if (warp_idx == 0) {
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
// TMA load warp
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
||||
#pragma unroll
|
||||
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
||||
// Wait consumer release
|
||||
empty_barriers[stage_idx]->wait(phase ^ 1);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait(phase ^ 1);
|
||||
// Compute offsets
|
||||
// NOTES: the group is always concatenated with the outer dimension
|
||||
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
|
||||
shape_m, BLOCK_M, m_block_idx);
|
||||
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
|
||||
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
||||
|
||||
// Compute offsets
|
||||
// NOTES: the group is always concatenated with the outer dimension
|
||||
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
|
||||
shape_m, BLOCK_M, m_block_idx);
|
||||
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
|
||||
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
||||
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
||||
// And for all m-grouped GEMMs, A must be K-majored
|
||||
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
|
||||
uint32_t k_idx = k_block_idx * BLOCK_K;
|
||||
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
|
||||
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
||||
// And for all m-grouped GEMMs, A must be K-majored
|
||||
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
|
||||
uint32_t k_block_idx = k_iter * kNumStages + s;
|
||||
uint32_t k_idx = k_block_idx * BLOCK_K;
|
||||
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
|
||||
// Add 2 CTA offsets
|
||||
if constexpr (kNumMulticast > 1) {
|
||||
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
|
||||
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
||||
}
|
||||
|
||||
// Issue TMAs
|
||||
if (cute::elect_one_sync()) {
|
||||
if constexpr (kMajorA == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx);
|
||||
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx);
|
||||
}
|
||||
// Arrive at full barriers
|
||||
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
||||
if (is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
|
||||
if (not is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive(0u);
|
||||
// Add 2 CTA offsets
|
||||
if constexpr (kNumMulticast > 1) {
|
||||
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
|
||||
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait(phase ^ 1);
|
||||
if (is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive();
|
||||
if (not is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive(0u);
|
||||
// Issue TMAs
|
||||
if constexpr (kMajorA == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx);
|
||||
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx);
|
||||
|
||||
// Arrive at full barriers
|
||||
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
||||
if (is_leader_cta) {
|
||||
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
|
||||
} else {
|
||||
full_barriers[stage_idx]->arrive(0u);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (warp_idx == 1 and is_leader_cta) {
|
||||
// MMA issue warp
|
||||
@@ -268,88 +233,89 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
|
||||
// Wait tensor memory empty barrier arrival
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
|
||||
// Wait tensor memory empty barrier arrival
|
||||
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// UMMA and empty barrier arrival alias
|
||||
auto umma_arrive = [](const uint64_t* barrier) {
|
||||
if constexpr (kNumMulticast == 1) {
|
||||
cutlass::arch::umma_arrive(barrier);
|
||||
} else {
|
||||
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
|
||||
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
|
||||
}
|
||||
};
|
||||
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
|
||||
|
||||
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
||||
if (do_tmem_full_arrive)
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
||||
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
||||
// Wait TMA arrival
|
||||
full_barriers[stage_idx]->wait(phase);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) {
|
||||
auto umma_arrive = [](const uint64_t* barrier) {
|
||||
if constexpr (kNumMulticast == 1) {
|
||||
cutlass::arch::umma_arrive(barrier);
|
||||
} else {
|
||||
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
|
||||
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
|
||||
}
|
||||
};
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[s]));
|
||||
|
||||
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
||||
if (do_tmem_full_arrive)
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
|
||||
// Issue UMMA in the leader CTA
|
||||
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_F16BF16_SS, SM100_MMA_F16BF16_2x1SM_SS>;
|
||||
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
||||
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
|
||||
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
|
||||
if (cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait TMA arrival
|
||||
full_barriers[s]->wait(phase);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Let tensor cores relax for lower possibility of frequency drop
|
||||
DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control");
|
||||
if constexpr (kTensorCoreUtilControl < 100) {
|
||||
constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull;
|
||||
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
|
||||
const auto& start_clock = clock64();
|
||||
if (cute::elect_one_sync())
|
||||
while (clock64() - start_clock < kNumDummyCycles) {}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Issue UMMA in the leader CTA
|
||||
using cute_mma_t = cute::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_MMA_F16BF16_SS <cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>,
|
||||
cute::SM100_MMA_F16BF16_2x1SM_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>>;
|
||||
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
||||
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s);
|
||||
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s);
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
||||
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
||||
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
|
||||
cute_mma_t::fma(a_desc, b_desc,
|
||||
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
|
||||
k_iter > 0 or s > 0 or k > 0,
|
||||
runtime_instr_desc);
|
||||
}
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
|
||||
mma_t::fma(a_desc, b_desc,
|
||||
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
|
||||
k_block_idx > 0 or k > 0,
|
||||
runtime_instr_desc);
|
||||
}
|
||||
|
||||
// Commit to the mbarrier object
|
||||
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
||||
empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait(phase);
|
||||
empty_barrier_arrive(s, false);
|
||||
}
|
||||
});
|
||||
});
|
||||
// Commit to the mbarrier object
|
||||
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
||||
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
|
||||
|
||||
// Let tensor cores relax for lower possibility of frequency drop
|
||||
DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control");
|
||||
if constexpr (kTensorCoreUtilControl < 100) {
|
||||
// For utilization control
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(tensor_core_full_barrier));
|
||||
|
||||
// Wait for last UMMA to be done
|
||||
tensor_core_full_barrier->wait(tensor_core_phase);
|
||||
tensor_core_phase ^= 1;
|
||||
|
||||
// Sleep for certain cycles
|
||||
constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull;
|
||||
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
|
||||
const auto& start_clock = clock64();
|
||||
if (cute::elect_one_sync())
|
||||
while (clock64() - start_clock < kNumDummyCycles) {}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// To safely deconstruct barriers, we need another round of waits
|
||||
const auto& iter_idx = scheduler.current_iter - 1;
|
||||
if (kNumMulticast > 1 and iter_idx >= 0) {
|
||||
const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
|
||||
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
|
||||
}
|
||||
} else if (warp_idx >= kNumNonEpilogueThreads / 32) {
|
||||
// Epilogue warp groups
|
||||
const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads;
|
||||
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
|
||||
|
||||
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
||||
@@ -363,129 +329,114 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
||||
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
|
||||
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
|
||||
|
||||
// Share store pipeline between blocks
|
||||
uint32_t tma_stage_idx = 0;
|
||||
auto advance_store_pipeline = [&]() {
|
||||
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
|
||||
};
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
|
||||
// Flush TMA stores
|
||||
// NOTES: for the first store, we have to flush all previous TMA,
|
||||
// as we don't share pipeline stages between two blocks
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
// Wait UMMA arrival
|
||||
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Wait UMMA arrival
|
||||
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
|
||||
tcgen05_after_thread_sync();
|
||||
// Load from tensor memory into registers, and write shared memory with STSM
|
||||
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
|
||||
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
||||
|
||||
// Load from tensor memory into registers, and write shared memory with STSM
|
||||
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
|
||||
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
||||
|
||||
// Iterate over M waves
|
||||
// Iterate over M waves
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
// Issue every swizzled atom and pipeline STSM and TMA store
|
||||
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
// Issue every swizzled atom and pipeline STSM and TMA store
|
||||
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
|
||||
for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
|
||||
// Wait shared memory to be released
|
||||
if (epilogue_warp_idx == 0)
|
||||
cute::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0);
|
||||
|
||||
// The pipeline stage
|
||||
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
|
||||
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
|
||||
|
||||
// Store into shared memory
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStores; ++ s) {
|
||||
// Wait shared memory to be released
|
||||
const uint32_t iter_idx = w * kNumStores + s;
|
||||
if (iter_idx >= kNumTMAStoreStages) {
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
}
|
||||
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
|
||||
|
||||
// The pipeline stage
|
||||
const auto tma_stage_idx = iter_idx % kNumTMAStoreStages;
|
||||
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
|
||||
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
|
||||
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
|
||||
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
|
||||
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleCDMode / 16);
|
||||
|
||||
// Store into shared memory
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
|
||||
// Source and destination memory address
|
||||
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
|
||||
w * BLOCK_N + // Wave offset
|
||||
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
|
||||
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
|
||||
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
|
||||
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
|
||||
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
|
||||
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleCDMode / 16);
|
||||
|
||||
// Source and destination memory address
|
||||
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
|
||||
w * BLOCK_N + // Wave offset
|
||||
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
|
||||
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
|
||||
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
|
||||
// Load from tensor memory, store into shared memory
|
||||
uint32_t values[kNumElemsPerBankGroup];
|
||||
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
|
||||
// For FP32 output, read and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
||||
} else {
|
||||
// For BF16 output, read, cast and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3],
|
||||
values[4], values[5], values[6], values[7]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr,
|
||||
cast_into_bf16_and_pack(values[0], values[1]),
|
||||
cast_into_bf16_and_pack(values[2], values[3]),
|
||||
cast_into_bf16_and_pack(values[4], values[5]),
|
||||
cast_into_bf16_and_pack(values[6], values[7]));
|
||||
}
|
||||
}
|
||||
|
||||
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
|
||||
// NOTES: only the last stage needs to do this
|
||||
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
|
||||
tcgen05_before_thread_sync();
|
||||
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Synchronize all threads and issue TMA
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
if (epilogue_thread_idx == 0) {
|
||||
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
||||
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
|
||||
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
|
||||
cute::tma_store_arrive();
|
||||
// Load from tensor memory, store into shared memory
|
||||
uint32_t values[kNumElemsPerBankGroup];
|
||||
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
|
||||
// For FP32 output, read and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
||||
} else {
|
||||
// For BF16 output, read, cast and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3],
|
||||
values[4], values[5], values[6], values[7]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr,
|
||||
cast_into_bf16_and_pack(values[0], values[1]),
|
||||
cast_into_bf16_and_pack(values[2], values[3]),
|
||||
cast_into_bf16_and_pack(values[4], values[5]),
|
||||
cast_into_bf16_and_pack(values[6], values[7]));
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
|
||||
// NOTES: only the last stage needs to do this
|
||||
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
|
||||
tcgen05_before_thread_sync();
|
||||
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Synchronize all threads and issue TMA
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0);
|
||||
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
|
||||
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
||||
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
|
||||
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deallocate tensor memory by warp 1
|
||||
// NOTES: warp 0 is waiting TMA store
|
||||
if (epilogue_warp_idx == 1)
|
||||
Allocator().free(0, kNumTmemCols);
|
||||
}
|
||||
|
||||
// To safely deconstruct all barriers, we need a cluster sync
|
||||
// TODO: optimize it by another round of barrier waits
|
||||
if constexpr (kNumMulticast > 1)
|
||||
cute::cluster_sync();
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");
|
||||
|
||||
265
deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh
Normal file
265
deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh
Normal file
@@ -0,0 +1,265 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/util/type_traits.hpp>
|
||||
#include <cutlass/arch/barrier.h>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm100_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm100;
|
||||
|
||||
// Split-K style BF16 GEMM kernel: each CTA multiplies one `BLOCK_M x BLOCK_N` output tile over up to
// `kSplitFactor` K-blocks, accumulates in FP32 tensor memory, and adds its partial result into D with
// a TMA reduce-add (`SM90_TMA_REDUCE_ADD_2D`).
//
// Template parameters:
//  - `SHAPE_M`, `SHAPE_N`, `SHAPE_K`: compile-time problem sizes used for block counting and TMA coordinates
//  - `BLOCK_M`, `BLOCK_N`, `BLOCK_K`: tile sizes, statically asserted to be 128 / 128 / 64
//  - `kSplitFactor`: maximum number of `BLOCK_K`-sized steps handled by one CTA
//  - `kSwizzleABMode`, `kSwizzleCDMode`: shared memory swizzle modes, statically asserted to be 128
//  - `kNumStages`: number of A/B shared memory pipeline stages (statically asserted `<= 32`)
//  - `kNumThreads`: launch bound and participant count of the epilogue named barrier
//    NOTE(review): the epilogue writes 32 rows per warp into a 128-row tile, so this presumably must be 128 — confirm
//
// Runtime parameters:
//  - `shape_s`: extent of the extra dimension that is fused with K when splitting work; the TMA row coordinate
//    of A/B is offset by `s_idx * SHAPE_M` / `s_idx * SHAPE_N`
//    NOTE(review): presumably A is `[shape_s * SHAPE_M, SHAPE_K]` and B is `[shape_s * SHAPE_N, SHAPE_K]` — confirm against the host side
//  - `tensor_map_a`, `tensor_map_b`, `tensor_map_d`: TMA descriptors for A, B and the FP32 output D
//    NOTE(review): as D is only ever reduce-added, the caller presumably has to initialize it — confirm
//
// Warp roles: warp 0 issues TMA loads, warp 1 issues UMMAs, warp 2 allocates tensor memory;
// afterwards all warps run the epilogue together.
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kSplitFactor,
|
||||
uint32_t kSwizzleABMode, uint32_t kSwizzleCDMode,
|
||||
uint32_t kNumStages, uint32_t kNumThreads>
|
||||
__global__ void __launch_bounds__(kNumThreads, 1)
|
||||
sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t LAYOUT_AD_M = 128;
|
||||
// Number of shared memory buffers rotated through by the TMA stores in the epilogue
constexpr uint32_t kNumTMAStoreStages = 2;
|
||||
|
||||
// Utils
|
||||
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
const auto lane_idx = get_lane_idx();
|
||||
DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size");
|
||||
DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode");
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
|
||||
// Shared memory sizes
|
||||
// NOTES: the buffer is laid out as: CD store stages, A stages, B stages, barriers, tensor memory pointer
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode;
|
||||
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
}
|
||||
|
||||
// Real tensor memory size and offsets
|
||||
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_N>();
|
||||
|
||||
// Fill D/A/B
|
||||
// NOTES: each visitor maps a stage index to the start of that stage's buffer
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE));
|
||||
});
|
||||
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
||||
});
|
||||
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
||||
});
|
||||
|
||||
// Fill barriers
|
||||
// NOTES: `kNumStages` full barriers, then `kNumStages` empty barriers, then one tensor memory full barrier
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
|
||||
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
||||
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
||||
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
||||
auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2);
|
||||
|
||||
// Fill the tensor memory pointer
|
||||
// NOTES: placed right after the last barrier
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 2 + 1);
|
||||
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
||||
|
||||
// Initialize barriers
|
||||
// NOTES: every barrier is initialized with an arrival count of 1
if (warp_idx == 1 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(1);
|
||||
}
|
||||
tmem_full_barrier->init(1);
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_barrier_init();
|
||||
} else if (warp_idx == 2) {
|
||||
// Allocate tensor memory
|
||||
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Block indices
|
||||
// NOTES: `blockIdx.x` is decomposed as `(sk_block_idx, m_block_idx, n_block_idx)` with N as the fastest dimension;
// `sk_block_idx` selects a chunk of `kSplitFactor` K-blocks out of the fused `shape_s * (SHAPE_K / BLOCK_K)` K-blocks
const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
|
||||
const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
|
||||
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
|
||||
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
|
||||
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
|
||||
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
|
||||
// Number of K-blocks for this CTA, clamped for the last chunk
// NOTE(review): the subtraction is unsigned; presumably the launch grid guarantees
// `sk_block_idx * kSplitFactor` never exceeds the total number of K-blocks — confirm
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
|
||||
|
||||
if (warp_idx == 0) {
|
||||
// TMA load warp
|
||||
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
||||
// Wait for the stage buffer to be released, the expected parity flips on every wrap around `kNumStages`
const auto& stage_idx = s % kNumStages;
|
||||
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
|
||||
|
||||
// Split the fused K-offset into the offset inside `SHAPE_K` and the index along the `shape_s` dimension
uint32_t m_idx = BLOCK_M * m_block_idx;
|
||||
uint32_t n_idx = BLOCK_N * n_block_idx;
|
||||
uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
|
||||
uint32_t k_idx = sk_idx % SHAPE_K;
|
||||
uint32_t s_idx = sk_idx / SHAPE_K;
|
||||
|
||||
// Issue TMAs
|
||||
if (cute::elect_one_sync()) {
|
||||
tma_copy<BLOCK_K, BLOCK_M, kSwizzleABMode, 1>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M);
|
||||
tma_copy<BLOCK_K, BLOCK_N, kSwizzleABMode, 1>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N);
|
||||
}
|
||||
|
||||
// Arrive at full barriers
|
||||
// NOTES: the expected transaction bytes cover one A stage plus one B stage
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
||||
if (cute::elect_one_sync())
|
||||
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
|
||||
}
|
||||
} else if (warp_idx == 1) {
|
||||
// MMA issue warp
|
||||
// NOTES: only the leader CTA will do this
|
||||
// NOTE(review): no cluster or multicast logic is visible in this kernel, the "leader CTA" wording looks inherited — confirm
// Make instruction descriptor
|
||||
constexpr uint32_t UMMA_M = LAYOUT_AD_M;
|
||||
constexpr uint32_t UMMA_N = BLOCK_N;
|
||||
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t);
|
||||
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
|
||||
|
||||
// NOTES: lane `i` keeps the low descriptor word of stage `i` (hence at most 32 stages),
// the one for the current stage is broadcast by a warp shuffle below
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
||||
auto a_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleABMode>(smem_a[0], 0, 0);
|
||||
auto b_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleABMode>(smem_b[0], 0, 0);
|
||||
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
||||
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
||||
|
||||
// Checks for MMA instructions
|
||||
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
|
||||
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
||||
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
|
||||
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
|
||||
"Invalid MMA instruction shape");
|
||||
|
||||
// Wait tensor memory empty barrier arrival
|
||||
// NOTE(review): no barrier wait is visible here (each CTA computes a single tile), only the fence below is issued
tcgen05_after_thread_sync();
|
||||
|
||||
// Launch MMAs
|
||||
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
||||
// Wait TMA arrival
|
||||
const auto& stage_idx = s % kNumStages;
|
||||
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Issue UMMA in the leader CTA
|
||||
// NOTE(review): same "leader CTA" wording as above — confirm
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
||||
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, stage_idx);
|
||||
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, stage_idx);
|
||||
if (cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
||||
a_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_M, kSwizzleABMode, cutlass::bfloat16_t>(a_desc_base_lo, 0, k * UMMA_K);
|
||||
b_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_N, kSwizzleABMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
|
||||
// NOTES: the accumulator lives at tensor memory column 0; the accumulate flag is false only for the
// very first UMMA of this CTA (`s == 0 and k == 0`)
SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc);
|
||||
}
|
||||
}
|
||||
|
||||
// Commit to the mbarrier object
|
||||
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
||||
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
|
||||
}
|
||||
// Signal the epilogue after the last stage has been issued
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barrier));
|
||||
}
|
||||
|
||||
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
||||
// i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`.
|
||||
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
||||
// NOTES: the epilogue below addresses tensor memory from column 0, so the allocated base must be 0
if (warp_idx == 2)
|
||||
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
||||
|
||||
// TMA checks
|
||||
// NOTES: from here on, all warps take part in the epilogue
constexpr uint32_t kNumBankGroupBytes = 16;
|
||||
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float);
|
||||
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(float);
|
||||
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
|
||||
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
|
||||
|
||||
// Wait UMMA arrival
|
||||
// NOTES: the barrier completes only once per CTA, so the waited parity is always 0
tmem_full_barrier->wait(0);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Load from tensor memory into registers, and write shared memory with STSM
|
||||
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
||||
|
||||
// Issue every swizzled atom and pipeline STSM and TMA store
|
||||
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStores; ++ s) {
|
||||
// Wait shared memory to be released
|
||||
// NOTES: the first `kNumTMAStoreStages` stores use fresh buffers and need no wait
if (s >= kNumTMAStoreStages) {
|
||||
if (warp_idx == 0 and cute::elect_one_sync())
|
||||
cute::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
cutlass::arch::NamedBarrier(kNumThreads).sync();
|
||||
}
|
||||
|
||||
// The pipeline stage
|
||||
const auto tma_stage_idx = s % kNumTMAStoreStages;
|
||||
const auto m_idx = m_block_idx * BLOCK_M;
|
||||
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
|
||||
|
||||
// Store into shared memory
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
|
||||
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
|
||||
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
|
||||
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleCDMode / 16);
|
||||
|
||||
// Source and destination memory address
|
||||
uint32_t tmem_addr = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
|
||||
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
|
||||
warp_idx * 32 * kSwizzleCDMode + // Warp offset
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
|
||||
// Load from tensor memory, store into shared memory
|
||||
uint32_t values[kNumElemsPerBankGroup];
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
||||
}
|
||||
|
||||
// Synchronize all threads and issue TMA
|
||||
// NOTES: the partial tile is added into D (reduce-add) instead of overwriting it
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumThreads).sync();
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
cute::SM90_TMA_REDUCE_ADD_2D::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
}
|
||||
|
||||
// Deallocate tensor memory by warp 1
|
||||
// NOTES: warp 0 is doing TMA stores
|
||||
// NOTES: the allocation above was done by warp 2, while the release is done by warp 1
if (warp_idx == 1)
|
||||
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
||||
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");
|
||||
#endif
|
||||
}
|
||||
|
||||
}
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
|
||||
#include <deep_gemm/common/epilogue_utils.cuh>
|
||||
#include <deep_gemm/common/scheduler.cuh>
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm100_utils.cuh>
|
||||
@@ -17,11 +18,12 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumStages,
|
||||
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
uint32_t kNumSMs,
|
||||
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t>
|
||||
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
|
||||
typename epilogue_type_t>
|
||||
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
||||
sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
@@ -96,8 +98,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == 0) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sfa);
|
||||
@@ -107,30 +108,25 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
cute::prefetch_tma_descriptor(&tensor_map_c);
|
||||
}
|
||||
|
||||
// Data on shared memory (layout as ordered below)
|
||||
cd_dtype_t* smem_cd[kNumTMAStoreStages];
|
||||
cutlass::float_e4m3_t* smem_a[kNumStages];
|
||||
cutlass::float_e4m3_t* smem_b[kNumStages];
|
||||
uint32_t* smem_sfa[kNumStages];
|
||||
uint32_t* smem_sfb[kNumStages];
|
||||
// D/A/B shared memory
|
||||
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
||||
});
|
||||
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
});
|
||||
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
});
|
||||
|
||||
// Fill D/A/B pointers
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i)
|
||||
smem_cd[i] = reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Fill SFA/SFB
|
||||
// SFA/SFB shared memory
|
||||
auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
smem_sfa[i] = reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
|
||||
smem_sfb[i] = reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
|
||||
}
|
||||
auto smem_sfa = PatternVisitor([=](const uint32_t& i) {
|
||||
return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
|
||||
});
|
||||
auto smem_sfb = PatternVisitor([=](const uint32_t& i) {
|
||||
return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
|
||||
});
|
||||
|
||||
// Fill barriers
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer +
|
||||
@@ -148,7 +144,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
||||
|
||||
// Initialize barriers
|
||||
if (threadIdx.x == 0) {
|
||||
if (warp_idx == 1 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
// Arrive at all CTAs
|
||||
@@ -166,9 +162,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::fence_barrier_init();
|
||||
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
|
||||
} else if (warp_idx == 2) {
|
||||
// Allocate tensor memory
|
||||
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
@@ -176,108 +171,75 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
uint32_t phase = 0;
|
||||
auto launch_k_iterations = [&](const auto& func) {
|
||||
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
|
||||
const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K);
|
||||
const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages;
|
||||
// Pipeline and TMA phases
|
||||
uint32_t stage_idx = 0, phase = 0;
|
||||
auto advance_pipeline = [&](uint32_t& k_block_idx) {
|
||||
++ k_block_idx;
|
||||
|
||||
// TODO: refactor here
|
||||
if (num_last_stages == 0) {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1)
|
||||
func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages);
|
||||
} else {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1)
|
||||
func(k_iter, DivisibleK{}, false, num_last_stages);
|
||||
func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1;
|
||||
}
|
||||
};
|
||||
|
||||
auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) {
|
||||
DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2,
|
||||
"Too many epilogue stages, please modify the Python heuristic as well");
|
||||
accum_stage_idx == 0 ? func(0) : func(1);
|
||||
// Flip phases only if reach the next first stage
|
||||
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
|
||||
phase ^= stage_idx == 0;
|
||||
};
|
||||
|
||||
// Dispatch warps into different roles
|
||||
if (warp_idx == 0) {
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
// TMA load warp
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
||||
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
||||
// Wait consumer release
|
||||
empty_barriers[stage_idx]->wait(phase ^ 1);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait(phase ^ 1);
|
||||
// Compute offsets
|
||||
// NOTES: the group is always concatenated with the outer dimension
|
||||
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
|
||||
shape_m, BLOCK_M, m_block_idx);
|
||||
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
|
||||
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
||||
|
||||
// Compute offsets
|
||||
// NOTES: the group is always concatenated with the outer dimension
|
||||
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
|
||||
shape_m, BLOCK_M, m_block_idx);
|
||||
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
|
||||
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
||||
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
||||
// And for all m-grouped GEMMs, A must be K-majored
|
||||
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
|
||||
uint32_t k_idx = k_block_idx * BLOCK_K;
|
||||
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
|
||||
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
||||
// And for all m-grouped GEMMs, A must be K-majored
|
||||
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
|
||||
uint32_t k_block_idx = k_iter * kNumStages + s;
|
||||
uint32_t k_idx = k_block_idx * BLOCK_K;
|
||||
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
|
||||
// Add 2 CTA offsets
|
||||
if constexpr (kNumMulticast > 1) {
|
||||
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
|
||||
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
||||
}
|
||||
|
||||
// Issue TMAs
|
||||
if (cute::elect_one_sync()) {
|
||||
if constexpr (kMajorA == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx);
|
||||
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx);
|
||||
}
|
||||
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
||||
|
||||
// Issue SFA and SFB TMAs at certain stages
|
||||
// No swizzling, so one TMA for one SF is enough
|
||||
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
|
||||
tma_copy<BLOCK_M, 1, 0, 1>(&tensor_map_sfa, full_barriers[s], smem_sfa[s], m_block_idx * BLOCK_M,
|
||||
scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad)));
|
||||
tma_copy<BLOCK_N, 1, 0, 1>(&tensor_map_sfb, full_barriers[s], smem_sfb[s], n_block_idx * BLOCK_N,
|
||||
scheduler.template get_global_idx<true, KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx));
|
||||
num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t);
|
||||
}
|
||||
|
||||
// Arrive at full barriers
|
||||
if (cute::elect_one_sync())
|
||||
full_barriers[s]->arrive_and_expect_tx(num_arrival_bytes);
|
||||
// Add 2 CTA offsets
|
||||
if constexpr (kNumMulticast > 1) {
|
||||
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
|
||||
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait(phase ^ 1);
|
||||
if (cute::elect_one_sync())
|
||||
full_barriers[s]->arrive();
|
||||
// Issue TMAs
|
||||
if constexpr (kMajorA == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx);
|
||||
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx);
|
||||
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
||||
|
||||
// Issue SFA and SFB TMAs at certain stages
|
||||
// No swizzling, so one TMA for one SF is enough
|
||||
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0) {
|
||||
tma_copy<BLOCK_M, 1, 0, 1>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M,
|
||||
scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad)));
|
||||
tma_copy<BLOCK_N, 1, 0, 1>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N,
|
||||
scheduler.template get_global_idx<true, KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx));
|
||||
num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t);
|
||||
}
|
||||
});
|
||||
|
||||
// Arrive at full barriers
|
||||
full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes);
|
||||
}
|
||||
}
|
||||
} else if (warp_idx == 1 and is_leader_cta) {
|
||||
// MMA issue warp
|
||||
@@ -307,101 +269,93 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
|
||||
// Wait tensor memory empty barrier arrival
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
|
||||
// Wait tensor memory empty barrier arrival
|
||||
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
|
||||
auto umma_arrive = [](const uint64_t* barrier) {
|
||||
if constexpr (kNumMulticast == 1) {
|
||||
cutlass::arch::umma_arrive(barrier);
|
||||
} else {
|
||||
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
|
||||
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
|
||||
}
|
||||
};
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
|
||||
|
||||
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
||||
if (do_tmem_full_arrive)
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
||||
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
||||
// Wait TMA and SF-transpose arrival
|
||||
with_sf_full_barriers[stage_idx]->wait(phase);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) {
|
||||
auto umma_arrive = [](const uint64_t* barrier) {
|
||||
if constexpr (kNumMulticast == 1) {
|
||||
cutlass::arch::umma_arrive(barrier);
|
||||
} else {
|
||||
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
|
||||
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
|
||||
}
|
||||
};
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[s]));
|
||||
|
||||
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
||||
if (do_tmem_full_arrive)
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
// Do SF copy at certain stages
|
||||
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
|
||||
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
|
||||
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
|
||||
|
||||
// SFA and SFB copy
|
||||
// TODO: process shared memory descriptor by addition
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait TMA and SF-transpose arrival
|
||||
with_sf_full_barriers[s]->wait(phase);
|
||||
tcgen05_after_thread_sync();
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
|
||||
replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
|
||||
replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Do SF copy at certain stages
|
||||
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
|
||||
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
|
||||
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
|
||||
|
||||
// SFA and SFB copy
|
||||
// TODO: process shared memory descriptor by addition
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfa[s] + i * kNumUTCCPAlignedElems;
|
||||
replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfb[s] + i * kNumUTCCPAlignedElems;
|
||||
replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Issue UMMA in the leader CTA
|
||||
using cute_mma_t = cute::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_MMA_MXF8F6F4_SS <cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
||||
cutlass::float_ue8m0_t, UMMA_M, UMMA_N, kMajorA, kMajorB>,
|
||||
cute::SM100_MMA_MXF8F6F4_2x1SM_SS<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
||||
cutlass::float_ue8m0_t, UMMA_M, UMMA_N, kMajorA, kMajorB>>;
|
||||
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx);
|
||||
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s);
|
||||
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s);
|
||||
// Issue UMMA in the leader CTA
|
||||
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_MXF8F6F4_SS, SM100_MMA_MXF8F6F4_2x1SM_SS>;
|
||||
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx);
|
||||
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
|
||||
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
|
||||
if (cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
||||
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t>(b_desc_base_lo, 0, k * UMMA_K);
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
||||
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t>(b_desc_base_lo, 0, k * UMMA_K);
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
|
||||
cute_mma_t::fma(a_desc, b_desc,
|
||||
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
|
||||
k_iter > 0 or s > 0 or k > 0,
|
||||
runtime_instr_desc,
|
||||
kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32),
|
||||
kTmemStartColOfSFB);
|
||||
}
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
|
||||
mma_t::fma(a_desc, b_desc,
|
||||
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
|
||||
k_block_idx > 0 or k > 0,
|
||||
runtime_instr_desc,
|
||||
kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32),
|
||||
kTmemStartColOfSFB);
|
||||
}
|
||||
|
||||
// Commit to the mbarrier object
|
||||
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
||||
empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
with_sf_full_barriers[s]->wait(phase);
|
||||
empty_barrier_arrive(s, false);
|
||||
}
|
||||
});
|
||||
});
|
||||
// Commit to the mbarrier object
|
||||
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
||||
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
|
||||
}
|
||||
}
|
||||
|
||||
// To safely deconstruct barriers, we need another round of waits
|
||||
const auto& iter_idx = scheduler.current_iter - 1;
|
||||
if (kNumMulticast > 1 and iter_idx >= 0) {
|
||||
const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
|
||||
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
|
||||
}
|
||||
} else if (warp_idx == 2) {
|
||||
// UTCCP transposer
|
||||
@@ -418,43 +372,30 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
};
|
||||
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
||||
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
||||
// Wait TMA arrival
|
||||
full_barriers[stage_idx]->wait(phase);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait TMA arrival
|
||||
full_barriers[s]->wait(phase);
|
||||
|
||||
// Transpose for UTCCP at certain stages
|
||||
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
|
||||
utccp_required_smem_warp_transpose(smem_sfa[s] + i * kNumUTCCPAlignedElems);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
|
||||
utccp_required_smem_warp_transpose(smem_sfb[s] + i * kNumUTCCPAlignedElems);
|
||||
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
}
|
||||
|
||||
// Arrive
|
||||
with_sf_full_barriers[s]->arrive(0u);
|
||||
// Transpose for UTCCP at certain stages
|
||||
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
|
||||
utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
|
||||
utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
|
||||
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait(phase);
|
||||
with_sf_full_barriers[s]->arrive(0u);
|
||||
}
|
||||
});
|
||||
// Arrive
|
||||
with_sf_full_barriers[stage_idx]->arrive(0u);
|
||||
}
|
||||
}
|
||||
} else if (warp_idx >= kNumNonEpilogueThreads / 32) {
|
||||
// Epilogue warp groups
|
||||
const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads;
|
||||
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
|
||||
|
||||
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
||||
@@ -468,129 +409,113 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
|
||||
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
|
||||
|
||||
// Share store pipeline between blocks
|
||||
uint32_t tma_stage_idx = 0;
|
||||
auto advance_store_pipeline = [&]() {
|
||||
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
|
||||
};
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
|
||||
// Flush TMA stores
|
||||
// NOTES: for the first store, we have to flush all previous TMA,
|
||||
// as we don't share pipeline stages between two blocks
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
// Wait UMMA arrival
|
||||
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Wait UMMA arrival
|
||||
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
|
||||
tcgen05_after_thread_sync();
|
||||
// Load from tensor memory into registers, and write shared memory with STSM
|
||||
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
|
||||
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
||||
|
||||
// Load from tensor memory into registers, and write shared memory with STSM
|
||||
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
|
||||
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
||||
|
||||
// Iterate over M waves
|
||||
// Iterate over M waves
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
// Issue every swizzled atom and pipeline STSM and TMA store
|
||||
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
// Issue every swizzled atom and pipeline STSM and TMA store
|
||||
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
|
||||
for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
|
||||
// Wait shared memory to be released
|
||||
if (epilogue_warp_idx == 0)
|
||||
cute::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0);
|
||||
|
||||
// The pipeline stage
|
||||
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
|
||||
const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(n_block_idx * BLOCK_N + s * STORE_BLOCK_N);
|
||||
|
||||
// Store into shared memory
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStores; ++ s) {
|
||||
// Wait shared memory to be released
|
||||
const uint32_t iter_idx = w * kNumStores + s;
|
||||
if (iter_idx >= kNumTMAStoreStages) {
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
}
|
||||
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
|
||||
|
||||
// The pipeline stage
|
||||
const auto tma_stage_idx = iter_idx % kNumTMAStoreStages;
|
||||
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
|
||||
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
|
||||
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
|
||||
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
|
||||
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleCDMode / 16);
|
||||
|
||||
// Store into shared memory
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
|
||||
// Source and destination memory address
|
||||
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
|
||||
w * BLOCK_N + // Wave offset
|
||||
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
|
||||
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
|
||||
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
|
||||
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
|
||||
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
|
||||
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleCDMode / 16);
|
||||
|
||||
// Source and destination memory address
|
||||
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
|
||||
w * BLOCK_N + // Wave offset
|
||||
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
|
||||
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
|
||||
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
|
||||
// Load from tensor memory, store into shared memory
|
||||
uint32_t values[kNumElemsPerBankGroup];
|
||||
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
|
||||
// For FP32 output, read and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
||||
} else {
|
||||
// For BF16 output, read, cast and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3],
|
||||
values[4], values[5], values[6], values[7]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr,
|
||||
cast_into_bf16_and_pack(values[0], values[1]),
|
||||
cast_into_bf16_and_pack(values[2], values[3]),
|
||||
cast_into_bf16_and_pack(values[4], values[5]),
|
||||
cast_into_bf16_and_pack(values[6], values[7]));
|
||||
}
|
||||
}
|
||||
|
||||
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
|
||||
// NOTES: only the last stage needs to do this
|
||||
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
|
||||
tcgen05_before_thread_sync();
|
||||
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Synchronize all threads and issue TMA
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
if (epilogue_thread_idx == 0) {
|
||||
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
||||
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
|
||||
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
|
||||
cute::tma_store_arrive();
|
||||
// Load from tensor memory, store into shared memory
|
||||
uint32_t values[kNumElemsPerBankGroup];
|
||||
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
|
||||
// For FP32 output, read and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
||||
} else {
|
||||
// For BF16 output, read, cast and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3],
|
||||
values[4], values[5], values[6], values[7]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr,
|
||||
cast_into_bf16_and_pack(values[0], values[1]),
|
||||
cast_into_bf16_and_pack(values[2], values[3]),
|
||||
cast_into_bf16_and_pack(values[4], values[5]),
|
||||
cast_into_bf16_and_pack(values[6], values[7]));
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
|
||||
// NOTES: only the last stage needs to do this
|
||||
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
|
||||
tcgen05_before_thread_sync();
|
||||
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
||||
}
|
||||
|
||||
// Synchronize all threads and issue TMA
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0);
|
||||
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
|
||||
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
||||
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
|
||||
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deallocate tensor memory by warp 1
|
||||
// NOTES: warp 0 is waiting TMA store
|
||||
if (epilogue_warp_idx == 1)
|
||||
Allocator().free(0, kNumTmemCols);
|
||||
}
|
||||
|
||||
// To safely deconstruct all barriers, we need a cluster sync
|
||||
// TODO: optimize it by another round of barrier waits
|
||||
if constexpr (kNumMulticast > 1)
|
||||
cute::cluster_sync();
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <deep_gemm/common/epilogue_utils.cuh>
|
||||
#include <deep_gemm/common/scheduler.cuh>
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm100_utils.cuh>
|
||||
@@ -22,7 +23,8 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
||||
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
uint32_t kNumSMs,
|
||||
GemmType kGemmType, typename cd_dtype_t>
|
||||
GemmType kGemmType, typename cd_dtype_t,
|
||||
typename epilogue_type_t>
|
||||
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
||||
sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
@@ -88,8 +90,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == 0) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
@@ -133,7 +134,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
||||
|
||||
// Initialize barriers
|
||||
if (threadIdx.x == 0) {
|
||||
if (warp_idx == 1 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
// Arrive at all CTAs
|
||||
@@ -149,9 +150,8 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::fence_barrier_init();
|
||||
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
|
||||
} else if (warp_idx == 2) {
|
||||
// Allocate tensor memory
|
||||
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
@@ -174,7 +174,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
|
||||
|
||||
// Register configurations
|
||||
constexpr uint32_t kNumNonEpilogueRegisters = 64;
|
||||
@@ -435,7 +435,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
// as we don't share pipeline stages between two blocks
|
||||
if (epilogue_thread_idx_in_warpgroup == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
|
||||
cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx);
|
||||
|
||||
// Write shared memory
|
||||
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
||||
@@ -449,13 +449,13 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
if (s >= kNumTMAStoreStages) {
|
||||
if (epilogue_thread_idx_in_warpgroup == 0)
|
||||
cute::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
|
||||
cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx);
|
||||
}
|
||||
|
||||
// The pipeline stage
|
||||
const auto tma_stage_idx = s % kNumTMAStoreStages;
|
||||
const auto m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_m, BLOCK_M, m_block_idx);
|
||||
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
|
||||
const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(n_block_idx * BLOCK_N + s * STORE_BLOCK_N);
|
||||
const auto local_smem_cd = smem_cd[tma_stage_idx] + epilogue_warpgroup_idx * STORE_BLOCK_M * STORE_BLOCK_N;
|
||||
|
||||
// Store into shared memory
|
||||
@@ -502,7 +502,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
|
||||
// Synchronize all threads and issue TMA
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
|
||||
cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx);
|
||||
if (epilogue_thread_idx_in_warpgroup == 0) {
|
||||
cute::SM90_TMA_STORE_2D::copy(
|
||||
&tensor_map_d, local_smem_cd,
|
||||
@@ -512,10 +512,6 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
}
|
||||
}
|
||||
|
||||
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
|
||||
if (epilogue_thread_idx_in_warpgroup == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
|
||||
// Deallocate tensor memory by warp 1
|
||||
// NOTES: warp 0 is waiting TMA store
|
||||
if (epilogue_warp_idx == 1)
|
||||
|
||||
@@ -25,7 +25,8 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
||||
uint32_t kNumSMs, GemmType kGemmType>
|
||||
uint32_t kNumSMs, GemmType kGemmType,
|
||||
typename cd_dtype_t>
|
||||
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
||||
sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
@@ -44,7 +45,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
||||
|
||||
// Shared memory
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(cd_dtype_t);
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
|
||||
|
||||
@@ -55,7 +56,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
const uint32_t lane_idx = get_lane_idx();
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
@@ -67,7 +68,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
|
||||
auto smem_d = reinterpret_cast<cd_dtype_t*>(smem_buffer);
|
||||
__nv_bfloat16* smem_a[kNumStages];
|
||||
__nv_bfloat16* smem_b[kNumStages];
|
||||
|
||||
@@ -91,7 +92,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
}
|
||||
|
||||
// Initialize barriers
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
@@ -99,7 +100,6 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
|
||||
@@ -125,14 +125,14 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
if (warp_idx >= kNumMathThreads / 32) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) {
|
||||
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type) {
|
||||
@@ -203,7 +203,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
}
|
||||
};
|
||||
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type) {
|
||||
@@ -237,11 +237,10 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival at the last warpgroup wave
|
||||
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
|
||||
// Notify barrier arrival
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
@@ -256,7 +255,6 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
|
||||
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
|
||||
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
|
||||
DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type");
|
||||
DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom");
|
||||
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
|
||||
"Unaligned TMA store or too many TMA store instructions");
|
||||
@@ -265,60 +263,76 @@ sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
// Wait last TMA store to be finished
|
||||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
|
||||
cute::tma_store_wait<0>();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
|
||||
|
||||
// Write back to shared memory using STSM and issue TMA stores
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
|
||||
if constexpr (std::is_same_v<cd_dtype_t, cutlass::bfloat16_t>) {
|
||||
// Write back to shared memory using STSM and issue TMA stores
|
||||
DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type");
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
// Swizzle or padding into the correct address
|
||||
uint8_t* smem_ptr = nullptr;
|
||||
if constexpr (kSwizzleDMode > 0) {
|
||||
// Calculate the swizzling atom offset and in-atom offset
|
||||
constexpr uint32_t kNumBankGroupBytes = 16;
|
||||
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
// Swizzle or padding into the correct address
|
||||
uint8_t* smem_ptr = nullptr;
|
||||
if constexpr (kSwizzleDMode > 0) {
|
||||
// Calculate the swizzling atom offset and in-atom offset
|
||||
constexpr uint32_t kNumBankGroupBytes = 16;
|
||||
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
|
||||
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
|
||||
// - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
|
||||
constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleDMode / 16);
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
|
||||
// - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
|
||||
constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleDMode / 16);
|
||||
|
||||
// Add back into the base pointer
|
||||
// NOTES: think twice before modifying this, as changes may affect the number of instructions
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
|
||||
warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
|
||||
m_offset * kSwizzleDMode + // Wave offset
|
||||
atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
} else {
|
||||
// No swizzling, just padding
|
||||
// TODO: support more cases
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
|
||||
// Add back into the base pointer
|
||||
// NOTES: think twice before modifying this, as changes may affect the number of instructions
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
|
||||
warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
|
||||
m_offset * kSwizzleDMode + // Wave offset
|
||||
atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
} else {
|
||||
// No swizzling
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
|
||||
}
|
||||
|
||||
// NOTES: only 16 lanes' addresses are used
|
||||
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
|
||||
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
|
||||
smem_ptr
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Use `st.shared` if STSM is not available
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
|
||||
auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2);
|
||||
auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2);
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
|
||||
st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
|
||||
}
|
||||
|
||||
// NOTES: only 16 lanes' addresses are used
|
||||
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
|
||||
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
|
||||
smem_ptr
|
||||
);
|
||||
}
|
||||
}
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
// TODO: compatible with FP32 output
|
||||
constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked;
|
||||
DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
|
||||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
|
||||
|
||||
173
deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh
Normal file
173
deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh
Normal file
@@ -0,0 +1,173 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm90_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm90;
|
||||
|
||||
// Split-K BF16 GEMM kernel (SM90, WGMMA + TMA) that accumulates FP32 partial sums into `d`.
//
// Grid mapping: `blockIdx.x % num_mn_blocks` selects the `BLOCK_M x BLOCK_N` output tile and
// `blockIdx.x / num_mn_blocks` selects which run of `kSplitFactor` K-blocks this block reduces.
// The reduction axis is `shape_s` segments of `SHAPE_K` elements each; segment `s_idx` is fetched
// at the second TMA coordinate offset `s_idx * SHAPE_M` (A) and `s_idx * SHAPE_N` (B).
//
// Parameters:
//   - `shape_s`: number of `SHAPE_K`-long segments along the reduction axis
//   - `tensor_map_a`, `tensor_map_b`: TMA descriptors used to load the A and B tiles
//   - `d`: FP32 output indexed as `d[row * SHAPE_N + col]`; it is only ever updated with `atomicAdd`
//     NOTE(review): presumably the caller zero-fills (or pre-loads) `d` before launch - confirm
//
// Thread layout (enforced by the static asserts below): 256 math threads (warps 0-7) followed by
// 128 TMA threads, of which a single elected thread issues the loads.
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kSplitFactor,
|
||||
uint32_t kNumStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
|
||||
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
||||
sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
||||
float *d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
// Types
|
||||
using WGMMA = typename BF16MMASelector<BLOCK_N>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
|
||||
|
||||
// Shared memory
|
||||
// NOTES: these per-stage sizes are also the byte counts expected on each full barrier
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
|
||||
|
||||
// Configs
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_idx();
|
||||
DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M");
|
||||
DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads");
|
||||
DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads");
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (warp_idx == 0 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
// Fill shared memory pointers
|
||||
// Layout: `kNumStages` A tiles, then `kNumStages` B tiles, then the full and empty barriers
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE));
|
||||
});
|
||||
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
||||
});
|
||||
|
||||
// Fill barriers
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
||||
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
||||
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
||||
|
||||
// Initialize barriers
|
||||
// NOTES: a full barrier expects 1 arrival (the TMA-issuing thread),
|
||||
// an empty barrier expects one arrival from every math thread
|
||||
if (warp_idx == 1 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(kNumMathThreads);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
__syncthreads();
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr uint32_t kNumTMARegisters = 40;
|
||||
constexpr uint32_t kNumMathRegisters = 232;
|
||||
|
||||
// Block indices
|
||||
// NOTES: `blockIdx.x` is decomposed as `(sk_block_idx, m_block_idx, n_block_idx)` with N varying fastest
|
||||
const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
|
||||
const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
|
||||
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
|
||||
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
|
||||
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
|
||||
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
|
||||
// Number of K-blocks this block reduces: `kSplitFactor`, or whatever remains of the
|
||||
// `shape_s * (SHAPE_K / BLOCK_K)` total for the last split
|
||||
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
|
||||
|
||||
if (warp_idx >= kNumMathThreads / 32) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
// Issue one pair of A/B TMA loads for every K-block assigned to this block
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
||||
// Wait consumer release
|
||||
// NOTES: the parity flips each time the stage ring wraps, and is inverted
|
||||
// relative to the parity the math threads use on the full barrier
|
||||
const auto& stage_idx = s % kNumStages;
|
||||
empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1);
|
||||
|
||||
auto& full_barrier = *full_barriers[stage_idx];
|
||||
// Offset along the flattened reduction axis, split into the offset inside
|
||||
// a segment (`k_idx`) and the segment index (`s_idx`)
|
||||
const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
|
||||
const uint32_t& k_idx = sk_idx % SHAPE_K;
|
||||
const uint32_t& s_idx = sk_idx / SHAPE_K;
|
||||
|
||||
// Both loads are tracked by the same full barrier; the expected
|
||||
// transaction size below is the A tile plus the B tile
|
||||
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1);
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
||||
// NOTES: the accumulators live across all stages, so the partial sum of this
|
||||
// block is built up in registers and written out only once at the end
|
||||
float accum[WGMMA::kNumAccum] = {0};
|
||||
|
||||
// Launch MMAs
|
||||
for (uint32_t s = 0; s < num_total_stages; ++ s) {
|
||||
// Wait TMA arrivals
|
||||
const auto& stage_idx = s % kNumStages;
|
||||
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
// Each math warpgroup reads its own `WGMMA::M`-row slice of the A tile,
|
||||
// while the B tile is shared by both warpgroups
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, 1);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Release the stage: every math thread arrives, matching the
|
||||
// `kNumMathThreads` arrival count the empty barrier was initialized with
|
||||
empty_barriers[stage_idx]->arrive();
|
||||
}
|
||||
|
||||
// Accumulate the partial tile into global memory
|
||||
// NOTES: every 4 consecutive accumulators cover one 8-column chunk: values 0/1 go to
|
||||
// `row` and values 2/3 to `row + 8`, both at the two adjacent columns starting at
|
||||
// `col + i * 8`; rows and chunks outside `SHAPE_M`/`SHAPE_N` are skipped
|
||||
const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
|
||||
const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
if (col + i * 8 >= SHAPE_N)
|
||||
break;
|
||||
if (row < SHAPE_M) {
|
||||
atomicAdd(reinterpret_cast<float2*>(d + (row + 0) * SHAPE_N + col + i * 8),
|
||||
make_float2(accum[i * 4 + 0], accum[i * 4 + 1]));
|
||||
}
|
||||
if (row + 8 < SHAPE_M) {
|
||||
atomicAdd(reinterpret_cast<float2*>(d + (row + 8) * SHAPE_N + col + i * 8),
|
||||
make_float2(accum[i * 4 + 2], accum[i * 4 + 3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
// Compiled for an unsupported architecture: report the failure once
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
}; // namespace deep_gemm
|
||||
@@ -1,3 +1,348 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: add implement
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/scheduler.cuh>
|
||||
#include <deep_gemm/common/sm90_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm90;
|
||||
|
||||
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
||||
uint32_t kNumSMs,
|
||||
GemmType kGemmType, typename cd_dtype_t>
|
||||
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
||||
sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
||||
int* grouped_layout,
|
||||
cute::TmaDescriptor* tensor_map_buffer,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a_base,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b_base,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads");
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
|
||||
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");
|
||||
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
|
||||
|
||||
// Overwrite shape constants if the compiler gives
|
||||
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
||||
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
||||
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
||||
|
||||
// Shared memory
|
||||
static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0);
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
|
||||
static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
|
||||
static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
|
||||
DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment");
|
||||
|
||||
// Configs
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = threadIdx.x % 32;
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a_base);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b_base);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sfa);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sfb);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
|
||||
// Tensor maps on shared and global memory
|
||||
auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * i);
|
||||
});
|
||||
auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * (2 + i));
|
||||
});
|
||||
auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; });
|
||||
auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; });
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_d = reinterpret_cast<float*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
|
||||
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
|
||||
});
|
||||
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
|
||||
});
|
||||
constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
||||
auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE));
|
||||
});
|
||||
auto smem_sfb = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE));
|
||||
});
|
||||
|
||||
// Barriers on shared memory
|
||||
constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE);
|
||||
auto full_barriers = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast<uint32_t>(sizeof(Barrier))));
|
||||
});
|
||||
auto empty_barriers = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast<uint32_t>(sizeof(Barrier))));
|
||||
});
|
||||
|
||||
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
|
||||
// Load tensormap A/B to shared memory
|
||||
if constexpr (kGemmType == GemmType::KGroupedContiguous) {
|
||||
*smem_tensor_map_a[0] = tensor_map_a_base;
|
||||
*smem_tensor_map_a[1] = tensor_map_a_base;
|
||||
*smem_tensor_map_b[0] = tensor_map_b_base;
|
||||
*smem_tensor_map_b[1] = tensor_map_b_base;
|
||||
}
|
||||
|
||||
// Initialize barriers
|
||||
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
|
||||
// even with TMA multicast disabled, we want to make the behavior aligned
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// Pipeline unroll control
|
||||
constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages);
|
||||
|
||||
// Register reconfigurations (more math registers are needed with unrolling)
|
||||
constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24);
|
||||
constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240);
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs, 128u>(shape_m, shape_n, shape_k, grouped_layout);
|
||||
|
||||
// TMA and MMA pipeline
|
||||
const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
||||
return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase
|
||||
};
|
||||
uint32_t iter_idx = 0;
|
||||
|
||||
if (warp_idx >= kNumMathThreads / 32) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base;
|
||||
const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base;
|
||||
uint32_t last_group_idx = kNumGroups, sum_k = 0;
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Assign TMA multicast number into A and B
|
||||
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
|
||||
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
|
||||
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
||||
|
||||
const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
|
||||
const uint32_t& m_idx = m_block_idx * BLOCK_M;
|
||||
const uint32_t& n_idx = n_block_idx * BLOCK_N;
|
||||
|
||||
if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) {
|
||||
const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1;
|
||||
const uint32_t& next_stage_idx = stage_idx ^ 1;
|
||||
last_group_idx = scheduler.current_group_idx;
|
||||
|
||||
// Prepare next tensor map
|
||||
sum_k += scheduler.current_shape_k;
|
||||
if (scheduler.next_group_idx < kNumGroups) {
|
||||
tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + sum_k * shape_m);
|
||||
tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + sum_k * shape_n);
|
||||
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
|
||||
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
|
||||
*(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]);
|
||||
*(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]);
|
||||
tensor_map_release_cta();
|
||||
}
|
||||
|
||||
// Get current tensor map
|
||||
if (scheduler.current_num_valid_groups > 0) {
|
||||
tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]);
|
||||
tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]);
|
||||
current_tensor_map_a = gmem_tensor_map_a[stage_idx];
|
||||
current_tensor_map_b = gmem_tensor_map_b[stage_idx];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll kNumPipelineUnrolls
|
||||
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) {
|
||||
// Wait consumer release
|
||||
CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
|
||||
empty_barriers[stage_idx]->wait(phase ^ 1);
|
||||
|
||||
// Issue TMA
|
||||
auto& full_barrier = *full_barriers[stage_idx];
|
||||
const uint32_t& k_idx = k_block_idx * BLOCK_K;
|
||||
const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
|
||||
tma_copy(&tensor_map_sfa, reinterpret_cast<uint64_t*>(&full_barrier), smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
|
||||
tma_copy(&tensor_map_sfb, reinterpret_cast<uint64_t*>(&full_barrier), smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
|
||||
tma_copy(current_tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier), smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
|
||||
tma_copy(current_tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier), smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE);
|
||||
}
|
||||
}
|
||||
|
||||
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
||||
if constexpr (kNumTMAMulticast > 1) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s) {
|
||||
CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
|
||||
empty_barriers[stage_idx]->wait(phase ^ 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
||||
const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4;
|
||||
const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8;
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Accumulation for WGMMA or CUDA promotion
|
||||
DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes");
|
||||
const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
|
||||
const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
|
||||
const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K);
|
||||
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
|
||||
float2 scales_b[WGMMA::kNumAccum / 4];
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](uint32_t s) {
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
||||
} else {
|
||||
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
|
||||
}
|
||||
};
|
||||
|
||||
#pragma unroll kNumPipelineUnrolls
|
||||
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) {
|
||||
// Wait TMA arrivals
|
||||
CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
|
||||
full_barriers[stage_idx]->wait(phase);
|
||||
|
||||
// Read A scales
|
||||
// NOTES: all shared memory reads must happen prior to `warpgroup_arrive` to avoid the next scheduled block polluting the results
|
||||
auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0);
|
||||
auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1);
|
||||
|
||||
// Read B scales
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
|
||||
scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival
|
||||
empty_barrier_arrive(stage_idx);
|
||||
|
||||
// Promote with scales
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
const float &scale_b_0 = scales_b[i].x;
|
||||
const float &scale_b_1 = scales_b[i].y;
|
||||
final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
|
||||
final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
|
||||
final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
|
||||
final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
|
||||
}
|
||||
}
|
||||
|
||||
// Flush previous stores
|
||||
if (warp_idx % 4 == 0 and cute::elect_one_sync())
|
||||
cute::tma_store_wait<0>();
|
||||
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
|
||||
|
||||
// Store to D shared memory
|
||||
const auto& smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
|
||||
const auto& smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
|
||||
st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
|
||||
}
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
if (warp_idx % 4 == 0 and cute::elect_one_sync()) {
|
||||
cute::SM90_TMA_REDUCE_ADD_2D::copy(
|
||||
&tensor_map_d, smem_d_0, n_block_idx * BLOCK_N,
|
||||
current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0);
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include <deep_gemm/common/epilogue_utils.cuh>
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/scheduler.cuh>
|
||||
#include <deep_gemm/common/sm90_utils.cuh>
|
||||
@@ -18,15 +19,15 @@ namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm90;
|
||||
|
||||
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd>
|
||||
__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) {
|
||||
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd, typename func_t>
|
||||
__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
|
||||
if (num_former_iters == kNumFormerIters) {
|
||||
inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
|
||||
func(cute::Int<kNumFormerIters>{});
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (kNumFormerIters + kGap <= kEnd)
|
||||
outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
|
||||
dispatch_num_former_iters<kNumFormerIters + kGap, kGap, kEnd>(num_former_iters, func);
|
||||
}
|
||||
|
||||
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
@@ -36,7 +37,8 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
||||
uint32_t kNumSMs, GemmType kGemmType>
|
||||
uint32_t kNumSMs, GemmType kGemmType,
|
||||
typename epilogue_type_t>
|
||||
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
||||
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
@@ -69,14 +71,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
const uint32_t& smem_sfb_size = align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
||||
const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
|
||||
const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K);
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_idx();
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sfa);
|
||||
@@ -90,35 +90,26 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
|
||||
__nv_fp8_e4m3* smem_a[kNumStages];
|
||||
__nv_fp8_e4m3* smem_b[kNumStages];
|
||||
float* smem_sfa[kNumStages];
|
||||
float* smem_sfb;
|
||||
|
||||
// TMA Barrier for both divisible and non-divisible cases
|
||||
Barrier* full_barriers[kNumStages];
|
||||
Barrier* empty_barriers[kNumStages];
|
||||
|
||||
// Fill shared memory pointers
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
smem_sfa[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SFA_SIZE_PER_STAGE);
|
||||
}
|
||||
smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE));
|
||||
auto smem_a = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
});
|
||||
auto smem_b = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
});
|
||||
constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
||||
auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE);
|
||||
});
|
||||
auto smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE);
|
||||
|
||||
// Fill barriers
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i] = barrier_start_ptr + i;
|
||||
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
|
||||
}
|
||||
auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
|
||||
auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
|
||||
|
||||
// Initialize barriers
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
|
||||
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
|
||||
// even with TMA multicast disabled, we want to make the behavior aligned
|
||||
#pragma unroll
|
||||
@@ -128,107 +119,72 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
struct SkipComputation {};
|
||||
struct NotSkipComputation {};
|
||||
auto launch_k_iterations = [=](const auto& func, bool skip_computation, uint32_t num_former_iters) {
|
||||
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
|
||||
constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
|
||||
constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
|
||||
|
||||
// NOTES: for too-many branches (> 5), we disable this optimization
|
||||
// Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
|
||||
outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) {
|
||||
if (skip_computation) {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type);
|
||||
} else if (shape_k % kFullKOfAllStages == 0) {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
|
||||
} else {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
|
||||
func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
|
||||
func(num_iterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type);
|
||||
}
|
||||
}, func, kShouldOptimize ? num_former_iters : 0);
|
||||
};
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr uint32_t kNumTMARegisters = 40;
|
||||
constexpr uint32_t kNumMathRegisters = 232;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// Pipeline and TMA phases
|
||||
uint32_t stage_idx = 0, phase = 0;
|
||||
auto advance_pipeline = [&](uint32_t& k_block_idx) {
|
||||
++ k_block_idx;
|
||||
|
||||
// Flip the phase only when the stage index wraps around to the first stage
|
||||
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
|
||||
phase ^= stage_idx == 0;
|
||||
};
|
||||
|
||||
if (warp_idx >= kNumMathThreads / 32) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) {
|
||||
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(divisible_type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
// Assign TMA multicast number into A and B
|
||||
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
|
||||
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
|
||||
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
||||
|
||||
// Assign TMA multicast number into A and B
|
||||
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
|
||||
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
|
||||
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
||||
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
||||
// Wait consumer release
|
||||
empty_barriers[stage_idx]->wait(phase ^ 1);
|
||||
|
||||
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
|
||||
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
// Issue TMA A
|
||||
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
|
||||
auto& full_barrier = *full_barriers[stage_idx];
|
||||
const uint32_t k_idx = k_block_idx * BLOCK_K;
|
||||
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
|
||||
num_tma_multicast_a);
|
||||
tma_copy(&tensor_map_sfa, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.get_global_idx<kWithGroupOffsetA>(shape_k_scales, 1, k_block_idx),
|
||||
num_tma_multicast_a);
|
||||
|
||||
// Issue TMA A
|
||||
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
|
||||
auto& full_barrier = *full_barriers[s];
|
||||
uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
|
||||
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[s], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
|
||||
num_tma_multicast_a);
|
||||
tma_copy(&tensor_map_sfa, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_sfa[s], m_block_idx * BLOCK_M,
|
||||
scheduler.get_global_idx<kWithGroupOffsetA>(shape_k_scales, 1, k_idx / BLOCK_K),
|
||||
num_tma_multicast_a);
|
||||
|
||||
// Issue TMA B
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
|
||||
num_tma_multicast_b);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
}, false, 0);
|
||||
// Issue TMA B
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
|
||||
num_tma_multicast_b);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
|
||||
}
|
||||
}
|
||||
|
||||
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
||||
if constexpr (kNumTMAMulticast > 1) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s)
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
|
||||
for (uint32_t i = 0; i < kNumStages; advance_pipeline(i))
|
||||
empty_barriers[stage_idx]->wait(phase ^ 1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -239,6 +195,11 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
||||
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
|
||||
|
||||
auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
|
||||
auto b_desc = make_smem_desc(smem_b[0], 1);
|
||||
const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
|
||||
const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Decide the number of scales B to load
|
||||
@@ -259,7 +220,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
|
||||
st_shared(smem_sfb + i, __ldg(local_sfb + i));
|
||||
}
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
|
||||
|
||||
// Accumulation for WGMMA or CUDA promotion
|
||||
constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2);
|
||||
@@ -267,90 +228,96 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](uint32_t s) {
|
||||
auto empty_barrier_arrive = [&]() {
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
||||
lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void();
|
||||
} else {
|
||||
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void();
|
||||
}
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) {
|
||||
constexpr bool kSkipComputation = cute::is_same_v<decltype(skip_type), SkipComputation>;
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(divisible_type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? kNumStages : kNumLastStages);
|
||||
// Skip useless computations
|
||||
if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) {
|
||||
// The compiler must know the dynamic variable `num_former_iters`'s real value
|
||||
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
|
||||
constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
|
||||
constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Read B scales
|
||||
float scale_b_0 = ld_shared(smem_sfb + k_iter * kNumStages + s), scale_b_1;
|
||||
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_b_1 = ld_shared(smem_sfb + k_iter * kNumStages + s + shape_k_scales);
|
||||
// Dispatch `num_former_iters` and launch MMAs
|
||||
dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) {
|
||||
#pragma unroll 8
|
||||
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
||||
const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
|
||||
const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
|
||||
|
||||
// Wait TMA arrivals
|
||||
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
|
||||
// Read B scales
|
||||
float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1;
|
||||
// NOTES: even though some blocks do not need to read the second row, we still load one to align with other blocks
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales);
|
||||
|
||||
// TODO: remove some useless computation for unaligned Ms
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
// Wait TMA arrivals
|
||||
full_barriers[stage_idx]->wait(phase);
|
||||
|
||||
// Read A scales
|
||||
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
|
||||
auto scale_a_0 = ld_shared(smem_sfa[s] + r_0 + m_offset);
|
||||
auto scale_a_1 = ld_shared(smem_sfa[s] + r_1 + m_offset);
|
||||
// TODO: remove some useless computation for unaligned Ms
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
// Read A scales
|
||||
// NOTES: all shared memory reads must happen prior to `warpgroup_arrive` to avoid the next scheduled block polluting the results
|
||||
auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0 + m_offset);
|
||||
auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1 + m_offset);
|
||||
|
||||
// Notify barrier arrival at the last warpgroup wave
|
||||
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
|
||||
empty_barrier_arrive(s);
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16;
|
||||
b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16;
|
||||
WGMMA::wgmma(a_desc, b_desc, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Promote with scales
|
||||
// NOTES: making it as predicates is very important for performance, comparing to two loops
|
||||
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
|
||||
float scale_0_1, scale_1_1;
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
|
||||
// Notify barrier arrival at the last warpgroup wave
|
||||
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
|
||||
empty_barrier_arrive();
|
||||
|
||||
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
|
||||
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
|
||||
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
||||
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
||||
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
||||
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
|
||||
}
|
||||
// Promote with scales
|
||||
// NOTES: making it as predicates is very important for performance, comparing to two loops
|
||||
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
|
||||
float scale_0_1, scale_1_1;
|
||||
if constexpr (not kMustUseUniformedScaleB)
|
||||
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
|
||||
|
||||
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
|
||||
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
|
||||
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
|
||||
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
|
||||
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
|
||||
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
});
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
|
||||
empty_barrier_arrive(s);
|
||||
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
|
||||
full_barriers[stage_idx]->wait(phase);
|
||||
empty_barrier_arrive();
|
||||
}
|
||||
}, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters);
|
||||
}
|
||||
|
||||
// TMA checks
|
||||
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
|
||||
@@ -364,7 +331,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
// Wait last TMA store to be finished
|
||||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
|
||||
cute::tma_store_wait<0>();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
|
||||
|
||||
// Write back to shared memory using STSM and issue TMA stores
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||
@@ -413,7 +380,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
}
|
||||
}
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
// TODO: compatible with FP32 output
|
||||
@@ -423,7 +390,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
|
||||
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
|
||||
n_block_idx * BLOCK_N + in_block_n_offset,
|
||||
epilogue_type_t::apply_index_n<TMA_D_BLOCK_N>(n_block_idx * BLOCK_N + in_block_n_offset),
|
||||
scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx));
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
|
||||
# Profile
|
||||
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
|
||||
with suppress():
|
||||
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
|
||||
schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) if not using_nsys else None
|
||||
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
|
||||
with profiler:
|
||||
for i in range(2):
|
||||
@@ -112,10 +112,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
|
||||
is_tuple = isinstance(kernel_names, tuple)
|
||||
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
|
||||
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
if not with_multiple_kernels:
|
||||
for name in kernel_names:
|
||||
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
|
||||
assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'
|
||||
|
||||
# Save chrome traces
|
||||
if trace_path is not None:
|
||||
@@ -136,6 +135,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
|
||||
total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
|
||||
total_num += int(num_str)
|
||||
break
|
||||
kernel_times.append(total_time / total_num)
|
||||
kernel_times.append(total_time / total_num if total_num > 0 else 0)
|
||||
|
||||
return tuple(kernel_times) if is_tuple else kernel_times[0]
|
||||
|
||||
@@ -16,13 +16,16 @@ def ceil_to_ue8m0(x: torch.Tensor):
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
padded_n = align(n, 128)
|
||||
x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0)
|
||||
x_padded[:, :n] = x
|
||||
x_view = x_padded.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf
|
||||
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf
|
||||
|
||||
|
||||
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -54,4 +57,4 @@ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -
|
||||
sf = x_amax / 448.0
|
||||
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
return x_scaled, sf.squeeze()
|
||||
return x_scaled, sf.squeeze()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import enum
|
||||
import random
|
||||
import torch
|
||||
from typing import Generator, Tuple, List
|
||||
from typing import Generator, List
|
||||
|
||||
from deep_gemm.utils import (
|
||||
align, ceil_div,
|
||||
@@ -11,7 +11,6 @@ from deep_gemm.utils import (
|
||||
|
||||
|
||||
class KernelType(enum.Enum):
|
||||
# For SM100 GEMMs
|
||||
Kernel1D1D = 0
|
||||
Kernel1D2D = 1
|
||||
KernelNoSF = 2
|
||||
@@ -48,62 +47,87 @@ def get_ue8m0_usage(kernel_type: KernelType) -> bool:
|
||||
return kernel_type.is_1d1d()
|
||||
|
||||
|
||||
def get_kernel_types(use_bf16: bool = False) -> tuple:
|
||||
if use_bf16:
|
||||
def get_kernel_types(dtype: torch.dtype) -> tuple:
|
||||
if dtype == torch.bfloat16:
|
||||
return (KernelType.KernelNoSF, )
|
||||
return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D)
|
||||
|
||||
# TODO: SM100 1D2D kernels are going to be deprecated
|
||||
# But if you want to test it, please use:
|
||||
# `(KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D)`
|
||||
return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, )
|
||||
|
||||
|
||||
def get_out_dtype() -> tuple:
|
||||
return (torch.bfloat16, ) if get_arch_major() == 9 else (torch.bfloat16, torch.float)
|
||||
def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator:
|
||||
for major_a in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor):
|
||||
for major_b in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor):
|
||||
if major_a.is_mn_major() and not allow_a_mn_major:
|
||||
continue
|
||||
if major_b.is_mn_major() and not allow_b_mn_major:
|
||||
continue
|
||||
yield major_a, major_b
|
||||
|
||||
|
||||
def get_major_ab(freeze_a: bool) -> tuple:
|
||||
# TODO: test other major-ness for SM90 BF16 GEMMs
|
||||
if get_arch_major() == 9:
|
||||
return ((MajorTypeAB.KMajor, MajorTypeAB.KMajor), )
|
||||
if freeze_a:
|
||||
return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor)
|
||||
return (MajorTypeAB.KMajor, MajorTypeAB.KMajor), (MajorTypeAB.KMajor, MajorTypeAB.MNMajor), \
|
||||
(MajorTypeAB.MNMajor, MajorTypeAB.KMajor), (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
|
||||
def enumerate_normal(dtype: torch.dtype) -> Generator:
|
||||
assert dtype in (torch.float8_e4m3fn, torch.bfloat16)
|
||||
|
||||
fp32_output_nk = [(256, 7168), (129280, 7168)]
|
||||
bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]
|
||||
m_fwd_list, m_bwd_list = [128, 4096], [4096, ]
|
||||
nk_list = bf16_output_nk
|
||||
|
||||
# Only BF16 GEMM needs FP32 outputs
|
||||
if dtype == torch.bfloat16:
|
||||
nk_list += fp32_output_nk
|
||||
|
||||
for kernel_type in get_kernel_types(dtype):
|
||||
# Forward
|
||||
for m in m_fwd_list:
|
||||
for n, k in nk_list:
|
||||
out_dtype = torch.float if (n, k) in fp32_output_nk else torch.bfloat16
|
||||
yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype
|
||||
|
||||
# TODO: support BF16 SM90 MN-major kernels
|
||||
if dtype == torch.bfloat16 and get_arch_major() == 9:
|
||||
continue
|
||||
|
||||
# Backward
|
||||
for m in m_bwd_list:
|
||||
for n, k in nk_list:
|
||||
override_major = MajorTypeAB.MNMajor
|
||||
override_kernel_type = kernel_type
|
||||
if get_arch_major() == 9 and dtype == torch.float8_e4m3fn:
|
||||
override_major = MajorTypeAB.KMajor
|
||||
override_kernel_type = KernelType.Kernel1D1D
|
||||
yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad
|
||||
yield override_kernel_type, n, m, k, override_major, override_major, True, torch.float # Wgrad
|
||||
yield override_kernel_type, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad
|
||||
|
||||
|
||||
def enumerate_normal(use_bf16: bool = False) -> Generator:
|
||||
for kernel_type in get_kernel_types(use_bf16):
|
||||
for m in (128, 4096):
|
||||
for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]:
|
||||
for major_a, major_b in get_major_ab(False):
|
||||
for out_dtype in get_out_dtype():
|
||||
for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True):
|
||||
yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype
|
||||
|
||||
|
||||
def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator:
|
||||
for kernel_type in get_kernel_types(use_bf16):
|
||||
def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator:
|
||||
for kernel_type in get_kernel_types(dtype):
|
||||
for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)):
|
||||
for major_a, major_b in get_major_ab(True):
|
||||
for major_a, major_b in get_major_ab(False, get_arch_major() > 9):
|
||||
yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b
|
||||
|
||||
|
||||
def enumerate_m_grouped_masked() -> Generator:
|
||||
def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator:
|
||||
max_m = 4096
|
||||
for kernel_type in get_kernel_types():
|
||||
for kernel_type in get_kernel_types(dtype):
|
||||
for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
|
||||
for n, k in ((4096, 7168), (7168, 2048), ):
|
||||
yield kernel_type, num_groups, max_m, m, n, k
|
||||
|
||||
|
||||
def enumerate_k_grouped_contiguous():
|
||||
# TODO: support SM90 kernels
|
||||
if get_arch_major() == 9:
|
||||
return []
|
||||
|
||||
# Only K-major is supported for SM90
|
||||
major_a, major_b = (MajorTypeAB.KMajor, MajorTypeAB.KMajor) if get_arch_major() == 9 \
|
||||
else (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
|
||||
# Must with FP32 accumulation and 1D1D kernels
|
||||
for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64
|
||||
( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32
|
||||
(16, 4096, 7168, 2048), (16, 7168, 2048, 2048)): # EP16
|
||||
ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)]
|
||||
yield num_groups, m, n, ks, expected_k_per_group
|
||||
yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group
|
||||
|
||||
|
||||
def enumerate_sf_layout():
|
||||
@@ -134,6 +158,7 @@ def enumerate_transpose():
|
||||
def generate_normal(m: int, n: int, k: int,
|
||||
major_a: MajorTypeAB, major_b: MajorTypeAB,
|
||||
accumulate: bool, out_dtype: torch.dtype,
|
||||
kernel_type: KernelType,
|
||||
use_ue8m0: bool = False, use_bf16: bool = False):
|
||||
a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
@@ -147,7 +172,9 @@ def generate_normal(m: int, n: int, k: int,
|
||||
b = b if major_b.is_k_major() else b.T.contiguous().T
|
||||
return a, b, c, d, ref_d
|
||||
|
||||
a_fp8, b_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0), per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0)
|
||||
a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0)
|
||||
b_fp8 = per_token_cast_to_fp8(b, use_ue8m0=use_ue8m0) if kernel_type.is_1d1d() and accumulate \
|
||||
else per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0)
|
||||
a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1])
|
||||
b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1])
|
||||
return a_fp8, b_fp8, c, d, ref_d
|
||||
@@ -214,7 +241,7 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group:
|
||||
return a_fp8, b_fp8, masked_m, d, ref_d
|
||||
|
||||
|
||||
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int], use_ue8m0: bool):
|
||||
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], use_ue8m0: bool):
|
||||
assert get_mk_alignment_for_contiguous_layout() % 128 == 0
|
||||
k = sum(ks)
|
||||
|
||||
@@ -232,4 +259,20 @@ def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int]
|
||||
|
||||
a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0)
|
||||
b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0)
|
||||
|
||||
# Transpose for K Major A/B
|
||||
if (major_a, major_b) == (MajorTypeAB.KMajor, MajorTypeAB.KMajor):
|
||||
a, sfa = a_fp8
|
||||
b, sfb = b_fp8
|
||||
new_a = torch.empty((sum(ks) * m, ), dtype=a.dtype, device=a.device)
|
||||
new_b = torch.empty((sum(ks) * n, ), dtype=b.dtype, device=b.device)
|
||||
prefix = 0
|
||||
for K in ks:
|
||||
new_a[prefix * m : (prefix + K) * m] = a[prefix : prefix + K, ].T.flatten()
|
||||
new_b[prefix * n : (prefix + K) * n] = b[prefix : prefix + K, ].T.flatten()
|
||||
prefix += K
|
||||
a_fp8, b_fp8 = (new_a, sfa.T), (new_b, sfb.T)
|
||||
else:
|
||||
assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
|
||||
|
||||
return k, a_fp8, b_fp8, c, d, ref_d
|
||||
|
||||
64
tests/test_attention.py
Normal file
64
tests/test_attention.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import random
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm.testing import bench_kineto, calc_diff, count_bytes
|
||||
from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8
|
||||
|
||||
from generators import get_arch_major, generate_normal, get_ue8m0_usage, get_kernel_types, MajorTypeAB
|
||||
|
||||
|
||||
def apply_skip_head_mid(d: torch.Tensor, head_splits: Tuple[int, int, int]):
    """Insert a zero-filled "mid" segment into every head of a 2D tensor.

    The `n` columns of `d` are viewed as `n // (left + right)` heads of width
    `left + right`. Each head is rebuilt as `[left part | mid zeros | right part]`,
    so the result has `num_heads * (left + mid + right)` columns.

    Args:
        d: 2D tensor of shape `(m, n)`; `n` must be divisible by `left + right`.
        head_splits: `(left, mid, right)` widths of the three per-head segments.

    Returns:
        A new tensor of shape `(m, num_heads * (left + mid + right))` with the same
        dtype and device as `d`.

    Raises:
        AssertionError: if `n` is not a multiple of `left + right`.
    """
    left, mid, right = head_splits
    m, n = d.shape
    assert n % (left + right) == 0
    num_heads = n // (left + right)

    # Split and insert padding tensor
    d = d.view(m, num_heads, -1)
    d_left = d[:, :, :left]
    # NOTES: slice from `left` rather than `-right:`; the latter selects the whole
    # head when `right == 0` (because `-0 == 0`) instead of an empty segment
    d_right = d[:, :, left:]

    d_mid = torch.zeros((m, num_heads, mid), dtype=d.dtype, device=d.device)
    return torch.cat([d_left, d_mid, d_right], dim=2).view(m, -1)
|
||||
|
||||
|
||||
def test_gemm_skip_head_mid() -> None:
    """Check and benchmark `deep_gemm.fp8_gemm_nt_skip_head_mid`.

    For every FP8 kernel type and a small set of `(m, n, k)` shapes, the output and
    reference tensors from `generate_normal` are both expanded with
    `apply_skip_head_mid` (zero-filled mid segment per head). The kernel result is
    then compared against the expanded reference and the kernel is timed.
    """
    print('Testing GEMM skip head mid:')
    head_splits = (128, 64, 128)

    # Only K-major A/B with a non-accumulating BF16 output is exercised here
    major_a = major_b = MajorTypeAB.KMajor
    out_dtype = torch.bfloat16
    accumulate = False

    for kernel_type in get_kernel_types(dtype=torch.float8_e4m3fn):
        for m in (128, 4096):
            for n, k in [(32768, 512), (8192, 512)]:
                kernel_opt = '1D1D' if kernel_type.is_1d1d() else '1D2D'
                use_ue8m0 = get_ue8m0_usage(kernel_type)
                disable_ue8m0_cast = not use_ue8m0

                lhs, rhs, _, out, expected = generate_normal(
                    m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
                # Widen both tensors so they share the layout with the padded mid segment
                out = apply_skip_head_mid(out, head_splits)
                expected = apply_skip_head_mid(expected, head_splits)

                def launch():
                    return deep_gemm.fp8_gemm_nt_skip_head_mid(
                        lhs, rhs, out, head_splits, disable_ue8m0_cast=disable_ue8m0_cast)

                # Correctness
                launch()
                diff = calc_diff(out, expected)
                assert diff < 0.001, f'{m=}, {n=}, {k=}, {kernel_opt}, {diff:.5f}'

                # Performance
                t = bench_kineto(launch, 'fp8_gemm', suppress_kineto_output=True)
                tflops = 2 * m * n * k / t / 1e12
                gbps = (count_bytes(lhs, rhs, out)) / 1e9 / t
                print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}): '
                      f'{t * 1e6:4.0f} us | '
                      f'{tflops:4.0f} TFLOPS | '
                      f'{gbps:4.0f} GB/s')
    print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Fixed seeds for both RNGs used by the generators
    random.seed(0)
    torch.manual_seed(0)

    # Allow TF32 in PyTorch matmul / cuDNN paths
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    test_gemm_skip_head_mid()
|
||||
@@ -7,6 +7,7 @@ from deep_gemm.testing import (
|
||||
calc_diff, count_bytes
|
||||
)
|
||||
from generators import (
|
||||
get_arch_major,
|
||||
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, generate_normal,
|
||||
generate_m_grouped_contiguous, generate_m_grouped_masked
|
||||
)
|
||||
@@ -14,14 +15,18 @@ from generators import (
|
||||
|
||||
def test_gemm() -> None:
|
||||
print('Testing GEMM:')
|
||||
for _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(use_bf16=True):
|
||||
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16):
|
||||
# TODO: support accumulation for SM90 BF16 GEMM
|
||||
if get_arch_major() == 9 and accumulate:
|
||||
continue
|
||||
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
|
||||
acc_opt = f'acc={int(accumulate)}'
|
||||
|
||||
for test_alias in (False, True):
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True)
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True)
|
||||
func_name = f'bf16_gemm_{major_opt.lower() if test_alias else "nt"}'
|
||||
if test_alias:
|
||||
a = a if major_a.is_k_major() else a.T
|
||||
@@ -31,28 +36,22 @@ def test_gemm() -> None:
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 0.0001, (f'{m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=}, '
|
||||
f'{diff:.5f}, alias={test_alias}')
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True)
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True)
|
||||
|
||||
cublas_t = 0
|
||||
t = bench_kineto(lambda: deep_gemm.bf16_gemm_nt(a, b, d, c=c), 'bf16_gemm', suppress_kineto_output=True)
|
||||
if accumulate == 0 and out_dtype == torch.bfloat16:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cublas_t = bench_kineto(lambda: a @ b.T, 'nvjet', suppress_kineto_output=True)
|
||||
except Exception:
|
||||
pass
|
||||
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, {out_opt}, {acc_opt}): '
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True)
|
||||
print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): '
|
||||
f'{t * 1e6:5.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | '
|
||||
f'{cublas_t / t:.2f}x cuBLAS')
|
||||
f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_contiguous() -> None:
|
||||
print('Testing m-grouped contiguous GEMM:')
|
||||
|
||||
for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(use_bf16=True):
|
||||
for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(torch.bfloat16):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
|
||||
@@ -85,7 +84,7 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
print('Testing m-grouped masked GEMM:')
|
||||
|
||||
# TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
|
||||
for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked():
|
||||
for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.bfloat16):
|
||||
# Test correctness
|
||||
for i in range(10):
|
||||
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True)
|
||||
@@ -111,6 +110,27 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
print()
|
||||
|
||||
|
||||
def test_cublaslt_gemm() -> None:
    """Check and benchmark `deep_gemm.cublaslt_gemm_nt` on the BF16 cases from `enumerate_normal`."""
    print('Testing cuBLASLt GEMM:')
    for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16):
        # Tags used in the assertion message and the perf line
        major_opt = ('N' if major_a.is_k_major() else 'T') + ('T' if major_b.is_k_major() else 'N')
        out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
        acc_opt = f'acc={int(accumulate)}'

        a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True)

        def launch():
            return deep_gemm.cublaslt_gemm_nt(a, b, d, c=c)

        # Correctness
        launch()
        diff = calc_diff(d, ref_d)
        assert diff < 5e-7, f'{diff=}, ({m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=})'

        # Performance: only kernels whose name contains `nvjet` are timed
        t = bench_kineto(launch, 'nvjet', suppress_kineto_output=True)
        moved_bytes = count_bytes(a, b, d) + count_bytes(c) * int(accumulate)
        print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): '
              f'{t * 1e6:5.0f} us | '
              f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
              f'{moved_bytes / 1e9 / t:4.0f} GB/s')
    print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
@@ -121,5 +141,9 @@ if __name__ == '__main__':
|
||||
print(f' > {deep_gemm.__path__}\n')
|
||||
|
||||
test_gemm()
|
||||
test_m_grouped_gemm_contiguous()
|
||||
test_m_grouped_gemm_masked()
|
||||
# TODO: support SM100
|
||||
if get_arch_major() == 9:
|
||||
test_m_grouped_gemm_contiguous()
|
||||
test_m_grouped_gemm_masked()
|
||||
|
||||
test_cublaslt_gemm()
|
||||
|
||||
85
tests/test_einsum.py
Normal file
85
tests/test_einsum.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import random
|
||||
import torch
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm.testing import (
|
||||
bench, bench_kineto,
|
||||
calc_diff, count_bytes
|
||||
)
|
||||
|
||||
|
||||
def test_bmk_bnk_mn() -> None:
    """Check and benchmark `deep_gemm.einsum('bmk,bnk->mn', ...)`.

    BF16 inputs are reduced over the batch dimension into an `(m, n)` output. With
    an FP32 output, the output tensor is also passed as `c`, so the reference adds
    its initial contents; with a BF16 output, `c` is `None`.
    """
    print('Testing "bmk, bnk -> mn":')
    shapes_mnk = [(128, 384, 128), (256, 256, 256), (384, 128, 384)]
    for s in (129, 4096, 8192):
        for m, n, k in shapes_mnk:
            for dtype in (torch.float, torch.bfloat16):
                is_fp32 = dtype == torch.float
                lhs = torch.randn((s, m, k), dtype=torch.bfloat16, device='cuda')
                rhs = torch.randn((s, n, k), dtype=torch.bfloat16, device='cuda')
                out = torch.randn((m, n), dtype=dtype, device='cuda')
                acc = out if is_fp32 else None

                # Test correctness
                # NOTES: the reference is materialized before the kernel writes into `out`
                expected = (acc if is_fp32 else 0) + torch.bmm(lhs.float(), rhs.float().mT).sum(0)

                def launch():
                    return deep_gemm.einsum('bmk,bnk->mn', lhs, rhs, out, c=acc)

                launch()
                assert calc_diff(out, expected) < 1e-5

                t = bench_kineto(launch, 'bmn_bnk_mn_gemm_impl', suppress_kineto_output=True)
                dtype_opt = 'FP32' if is_fp32 else 'BF16'
                moved_bytes = count_bytes(lhs, rhs) + (out.numel() * 4)
                # Two positional arguments on purpose: `print` joins them with a space
                print(f' > Perf (b={s:4.0f}, {m=}, {n=}, {k=}, {dtype_opt}): ',
                      f'{t * 1e6:4.0f} us | '
                      f'{2 * s * m * n * k / t / 1e12:4.0f} TFLOPS | '
                      f'{moved_bytes / 1e9 / t:4.0f} GB/s')
    print()
|
||||
|
||||
|
||||
def test_bhr_hdr_bhd():
    """Check `deep_gemm.einsum('bhr,hdr->bhd', ...)` against `torch.einsum`, then benchmark it."""
    print('Testing "bhr, hdr -> bhd":')
    equation = 'bhr,hdr->bhd'
    h, r, d = 128, 512, 128
    for b in (128, 4096, 8192):
        x = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16)
        # `y` is the first `r` columns of a wider buffer, so its rows are strided by `r + 128`
        y_storage = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16)
        y = y_storage[:, :, :r]
        expected = torch.einsum(equation, x, y)
        z = torch.empty((b, h, d), device='cuda', dtype=torch.bfloat16)

        def launch():
            return deep_gemm.einsum(equation, x, y, z)

        launch()
        assert calc_diff(z, expected) < 1e-10

        t = bench_kineto(launch, 'nvjet', suppress_kineto_output=True)
        # Two positional arguments on purpose: `print` joins them with a space
        print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ',
              f'{t * 1e6:4.0f} us | '
              f'{2 * b * h * r * d / t / 1e12:.0f} TFLOPS | '
              f'{count_bytes((x, y, z)) / t / 1e9:.0f} GB/s')
    print()
|
||||
|
||||
|
||||
def test_bhd_hdr_bhr():
    """Check `deep_gemm.einsum('bhd,hdr->bhr', ...)` against `torch.einsum`, then benchmark it."""
    print('Testing "bhd, hdr -> bhr":')
    equation = 'bhd,hdr->bhr'
    h, r, d = 128, 512, 128
    for b in (128, 4096, 8192):
        x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16)
        # `y` is the first `r` columns of a wider buffer, so its rows are strided by `r + 128`
        y_storage = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16)
        y = y_storage[:, :, :r]
        expected = torch.einsum(equation, x, y)
        z = torch.empty((b, h, r), device='cuda', dtype=torch.bfloat16)

        def launch():
            return deep_gemm.einsum(equation, x, y, z)

        launch()
        assert calc_diff(z, expected) < 1e-10

        t = bench_kineto(launch, 'nvjet', suppress_kineto_output=True)
        # Two positional arguments on purpose: `print` joins them with a space
        print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ',
              f'{t * 1e6:4.0f} us | '
              f'{2 * b * h * r * d / t / 1e12:.0f} TFLOPS | '
              f'{count_bytes((x, y, z)) / t / 1e9:.0f} GB/s')
    print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Fixed seeds for both RNGs
    random.seed(0)
    torch.manual_seed(0)

    # Allow TF32 in PyTorch matmul / cuDNN paths
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    print('Library path:')
    print(f' > {deep_gemm.__path__}\n')

    # Run every einsum test in order
    for einsum_test in (test_bmk_bnk_mn, test_bhr_hdr_bhd, test_bhd_hdr_bhr):
        einsum_test()
|
||||
@@ -10,7 +10,7 @@ from deep_gemm.testing import (
|
||||
)
|
||||
|
||||
from generators import (
|
||||
KernelType, get_ue8m0_usage,
|
||||
KernelType, get_arch_major, get_ue8m0_usage,
|
||||
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
|
||||
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
|
||||
)
|
||||
@@ -18,7 +18,7 @@ from generators import (
|
||||
|
||||
def test_gemm() -> None:
|
||||
print('Testing GEMM:')
|
||||
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal():
|
||||
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
|
||||
@@ -26,42 +26,35 @@ def test_gemm() -> None:
|
||||
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
|
||||
use_ue8m0 = get_ue8m0_usage(kernel_type)
|
||||
disable_ue8m0_cast = not use_ue8m0
|
||||
recipe = (1, 1, 128) if kernel_type.is_1d1d() and accumulate else None
|
||||
|
||||
for test_alias in (False, True):
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0)
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
|
||||
func_name = f'fp8_gemm_{major_opt.lower() if test_alias else "nt"}'
|
||||
if test_alias:
|
||||
a = a if major_a.is_k_major() else (a[0].T, a[1].T)
|
||||
b = b if major_b.is_k_major() else (b[0].T, b[1].T)
|
||||
assert a[0].is_contiguous() and b[0].is_contiguous()
|
||||
getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe)
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, '
|
||||
f'{diff:.5f}, alias={test_alias}')
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0)
|
||||
|
||||
# Test launch overhead
|
||||
launch_start_t = time.time_ns()
|
||||
deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
launch_end_t = time.time_ns()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): '
|
||||
f'launch {(launch_end_t - launch_start_t) / 1e3:4.0f} us | {t * 1e6:4.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s')
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
|
||||
t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe),
|
||||
'fp8_gemm', suppress_kineto_output=True)
|
||||
cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True)
|
||||
print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): '
|
||||
f'{t * 1e6:4.0f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | '
|
||||
f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_contiguous() -> None:
|
||||
print('Testing m-grouped contiguous GEMM:')
|
||||
|
||||
for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous():
|
||||
for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
|
||||
@@ -86,7 +79,7 @@ def test_m_grouped_gemm_contiguous() -> None:
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}): '
|
||||
print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}): '
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s')
|
||||
@@ -97,7 +90,7 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
print('Testing m-grouped masked GEMM:')
|
||||
|
||||
# TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
|
||||
for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked():
|
||||
for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.float8_e4m3fn):
|
||||
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
|
||||
use_ue8m0 = get_ue8m0_usage(kernel_type)
|
||||
disable_ue8m0_cast = not use_ue8m0
|
||||
@@ -130,26 +123,31 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
def test_k_grouped_gemm_contiguous() -> None:
|
||||
print('Testing k-grouped contiguous GEMM:')
|
||||
|
||||
for num_groups, m, n, ks, expected_k_per_group in enumerate_k_grouped_contiguous():
|
||||
k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \
|
||||
else deep_gemm.k_grouped_fp8_gemm_tn_contiguous
|
||||
for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous():
|
||||
use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D)
|
||||
|
||||
for test_empty_groups in (False, True):
|
||||
new_ks = copy.deepcopy(ks)
|
||||
if test_empty_groups:
|
||||
if test_empty_groups and len(ks) > 1:
|
||||
new_ks[random.randint(0, num_groups - 1)] = 0
|
||||
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, new_ks, use_ue8m0=use_ue8m0)
|
||||
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0)
|
||||
new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda')
|
||||
deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, new_ks, new_ks_tensor, c=c)
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 0.001, f'{m=}, {n=}, {k=}, {i=}, {diff:.5f}'
|
||||
k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c)
|
||||
|
||||
do_check = True
|
||||
if do_check:
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 0.001, f'{m=}, {n=}, {k=}, {ks=}, {diff:.5f}'
|
||||
|
||||
# Test performance
|
||||
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, ks, use_ue8m0=use_ue8m0)
|
||||
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0)
|
||||
ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=c)
|
||||
k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): '
|
||||
|
||||
Reference in New Issue
Block a user