Make various updates and fixes (#198)

This commit is contained in:
Ray Wang
2025-09-25 16:19:07 +08:00
committed by GitHub
parent 79f48ee15a
commit 3f71de7aa9
45 changed files with 3281 additions and 1060 deletions

77
csrc/apis/attention.hpp Normal file
View File

@@ -0,0 +1,77 @@
#pragma once
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
#include "layout.hpp"
namespace deep_gemm::attention {
// FP8 GEMM `D = A @ B.T` with an `EpilogueHeadSplits<left, mid, right>` epilogue.
//
// Parameters:
//   a, b               - pairs of (FP8 e4m3 data, scaling factors); shapes `[M, K]` and `[N, K]`
//   d                  - BF16 or FP32 output of shape `[M, N + N / (left + right) * mid]`
//   head_splits        - `(left, mid, right)`; `N` must be a multiple of `left + right`
//   recipe             - scaling recipe, `(1, 1, 128)` or `(1, 128, 128)`; when empty it is derived
//                        from the scaling factor dtypes via `get_default_recipe`
//   compiled_dims      - forwarded unchanged to the selected kernel
//   disable_ue8m0_cast - forwarded to the scaling factor layout transform
//
// Fails a host assertion on any shape/dtype violation, and is a no-op when `M == 0`.
static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const std::tuple<int, int, int> &head_splits,
                                      std::optional<std::tuple<int, int, int>> recipe,
                                      const std::string& compiled_dims,
                                      const bool& disable_ue8m0_cast) {
    // Shape must be `[M, K] @ [N, K].T`
    const auto& major_a = get_major_type_ab(a.first);
    const auto& major_b = get_major_type_ab(b.first);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }

    // D must be N-major
    check_major_type_cd(d);

    // Type and shape checks
    const auto& [m , k ] = get_shape<2>(a.first);
    const auto& [n , k_] = get_shape<2>(b.first);
    const auto& [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Check head splits and N
    // NOTES: the splits are validated first, otherwise `left + right == 0` would make
    // the modulo/division below divide by zero instead of failing the assertion
    const auto& [left, mid, right] = head_splits;
    DG_HOST_ASSERT(left >= 0 and mid >= 0 and right >= 0 and left + right > 0);
    DG_HOST_ASSERT(n % (left + right) == 0 and n_ == n + n / (left + right) * mid);

    // Do nothing if the problem is empty
    if (m == 0)
        return;

    // Transform SFA and SFB into compute-required layout
    if (not recipe.has_value())
        recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
    DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128));
    const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
    const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);

    // Dispatch into different implements
    // NOTES: the kernels receive the logical `n` (rows of B), not the wider `n_` of D
    const auto& arch_major = device_runtime->get_arch_major();
    const auto& epilogue_type = fmt::format("EpilogueHeadSplits<{}, {}, {}>", left, mid, right);
    if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) {
        sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
    } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
        sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
    } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
        sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
    }
}
// Registers the attention-related GEMM entry points on the given pybind11 module.
// `recipe` defaults to "not provided", `compiled_dims` to "nk" and
// `disable_ue8m0_cast` to false; the remaining arguments are required.
static void register_apis(pybind11::module_& m) {
m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("head_splits"),
py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
}
} // namespace deep_gemm::attention

115
csrc/apis/einsum.hpp Normal file
View File

@@ -0,0 +1,115 @@
#pragma once
#include <pybind11/pybind11.h>
#include <torch/python.h>
#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/layout.hpp"
#include "../jit_kernels/impls/sm90_bmk_bnk_mn.hpp"
#include "../jit_kernels/impls/sm100_bmk_bnk_mn.hpp"
#include "../jit_kernels/impls/smxx_cublaslt.hpp"
namespace deep_gemm::einsum {
// Einsum `bmk,bnk->mn`: contracts the batch and K dimensions of `a` (`[S, M, K]`)
// and `b` (`[S, N, K]`) into `d` (`[M, N]`).
//
// Two modes are supported:
//   - FP32 `d`: accumulating form, `c` is mandatory and must alias `d`
//     (same data pointer, sizes and strides)
//   - BF16 `d`: `c` must be empty; the result is computed into a zero-filled FP32
//     workspace through the FP32 path and then cast-copied into `d`
//
// Fails a host assertion on any violated requirement or unsupported architecture.
static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d,
                       const std::optional<torch::Tensor>& c) {
    // Currently FP32 only support the accumulated expression
    if (d.scalar_type() == torch::kFloat) {
        // NOTES: `c` is optional at the API level, so check it before dereferencing
        DG_HOST_ASSERT(c.has_value());
        DG_HOST_ASSERT(c->data_ptr() == d.data_ptr() and c->sizes() == d.sizes() and c->strides() == d.strides());
    } else {
        DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
        DG_HOST_ASSERT(not c.has_value());

        const auto& workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32));
        DG_CUDA_RUNTIME_CHECK(cudaMemsetAsync(workspace.data_ptr(), 0, workspace.nbytes(),
                                              c10::cuda::getCurrentCUDAStream()));
        bmk_bnk_mn(a, b, workspace, workspace);

        // This line has an implicit FP32-to-BF16 casting
        d.copy_(workspace);
        return;
    }

    DG_HOST_ASSERT(a.is_contiguous());
    DG_HOST_ASSERT(b.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());

    const auto& [s , m, k ] = get_shape<3>(a);
    const auto& [s_, n, k_] = get_shape<3>(b);
    const auto& [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(s == s_ and k == k_);
    // The output must match the `mn` part of the expression
    DG_HOST_ASSERT(m == m_ and n == n_);

    // Dispatch implementation
    const auto& arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
    } else if (arch_major == 10) {
        sm100_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }
}
// Einsum `bhr,hdr->bhd` on BF16 tensors, delegated to `cublaslt_bhr_hdr_bhd`.
// A is `[b, h, r]`, B is `[h, d, r]` and D is `[b, h, d]`; all three must be BF16
// with a unit stride on their last dimension (other strides are not checked here).
static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D) {
const auto& [b , h , r ] = get_shape<3>(A);
const auto& [h_, d , r_] = get_shape<3>(B);
const auto& [b_, h__, d_] = get_shape<3>(D);
// Every dimension shared between two operands must agree
DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);
DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1);
DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);
cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d);
}
// Einsum `bhd,hdr->bhr` on BF16 tensors, delegated to `cublaslt_bhd_hdr_bhr`.
// A is `[b, h, d]`, B is `[h, d, r]` and D is `[b, h, r]`; all three must be BF16
// with a unit stride on their last dimension (other strides are not checked here).
static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D) {
const auto& [b , h , d ] = get_shape<3>(A);
const auto& [h_, d_ , r ] = get_shape<3>(B);
const auto& [b_, h__, r_] = get_shape<3>(D);
// Every dimension shared between two operands must agree
DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);
DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1);
DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);
cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d);
}
// Entry point for the hardcoded Einstein-summation kernels.
//
// Parameters:
//   expr - one of "bmk,bnk->mn", "bhr,hdr->bhd" or "bhd,hdr->bhr" (matched verbatim,
//          the expression is not canonicalized)
//   a, b - BF16 inputs
//   d    - BF16 or FP32 output
//   c    - optional FP32 accumulator; when given, `d` must be FP32 as well, and it is
//          only accepted by the "bmk,bnk->mn" expression
//
// Fails a host assertion on dtype violations or an unsupported expression.
static void einsum(const std::string& expr,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c) {
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
// An accumulator implies FP32 accumulation into an FP32 output
if (c.has_value()) {
DG_HOST_ASSERT(c->scalar_type() == torch::kFloat);
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
}
// Some hardcoded Einstein sum kernels
// TODO: support any expression
// TODO: canonicalize expression
if (expr == "bmk,bnk->mn") {
bmk_bnk_mn(a, b, d, c);
} else if (expr == "bhr,hdr->bhd") {
// The cuBLASLt-backed expressions do not take an accumulator
DG_HOST_ASSERT(not c.has_value());
bhr_hdr_bhd(a, b, d);
} else if (expr == "bhd,hdr->bhr") {
DG_HOST_ASSERT(not c.has_value());
bhd_hdr_bhr(a, b, d);
} else {
DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr));
}
}
// Registers `einsum` on the given pybind11 module; the accumulator `c` is optional
// and defaults to "not provided".
static void register_apis(pybind11::module_& m) {
m.def("einsum", &einsum,
py::arg("expr"), py::arg("a"), py::arg("b"),
py::arg("d"), py::arg("c") = std::nullopt);
}
} // namespace deep_gemm::einsum

View File

@@ -1,5 +1,6 @@
#pragma once
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
@@ -52,13 +53,18 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
// Transform SFA and SFB into compute-required layout
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128));
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
// Dispatch into different implements
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
if (std::get<1>(recipe.value()) == 1) {
sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
} else {
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
}
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
@@ -261,6 +267,60 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torc
}
}
// K-grouped contiguous FP8 GEMM (NT layout), only implemented for SM90 with the
// 1D1D recipe `(1, 1, 128)`.
//
// Parameters:
//   a, b          - pairs of (FP8 data, scaling factors); the data tensors must be contiguous
//                   and hold `m * sum(ks)` and `n * sum(ks)` elements respectively
//   d             - contiguous output of shape `[num_groups, m, n]`
//   ks            - per-group K sizes on the host
//   ks_tensor     - tensor counterpart of `ks`, forwarded to the layout transform and the kernel
//   c             - optional contiguous FP32 tensor forwarded to the kernel
//   recipe        - must be `(1, 1, 128)`
//   compiled_dims - forwarded unchanged to the kernel
//
// Fails a host assertion on any violated requirement; a no-op when `sum(ks) == 0`.
// NOTE(review): `num_groups` is not compared against `ks.size()` here - confirm the kernel covers it.
static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                             const std::pair<torch::Tensor, torch::Tensor>& b,
                                             const torch::Tensor& d,
                                             const std::vector<int>& ks,
                                             const torch::Tensor& ks_tensor,
                                             const std::optional<torch::Tensor>& c,
                                             const std::tuple<int, int, int>& recipe,
                                             const std::string& compiled_dims) {
    // Must be 1D1D kernel
    DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));

    // Shape checks
    // NOTES: sums and products are done in 64 bits, as `m * sum_k` may exceed the `int`
    // range while `numel()` is 64-bit
    const auto& [num_groups, m, n] = get_shape<3>(d);
    const auto& sum_mk = a.first.numel();
    const auto& sum_nk = b.first.numel();
    const auto sum_k = std::accumulate(ks.begin(), ks.end(), static_cast<int64_t>(0));
    DG_HOST_ASSERT(sum_mk == static_cast<int64_t>(m) * sum_k);
    DG_HOST_ASSERT(sum_nk == static_cast<int64_t>(n) * sum_k);

    // Contiguity checks
    DG_HOST_ASSERT(a.first.is_contiguous());
    DG_HOST_ASSERT(b.first.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());
    if (c.has_value()) {
        DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(c.value().is_contiguous());
    }

    // Do nothing if empty
    if (sum_k == 0)
        return;

    // Transform SF with padding
    const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
    const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);

    // Allocate tensormap buffer
    // `4` means the double buffering for both A and B operands (2 * 2)
    const auto& num_sms = device_runtime->get_num_sms();
    const auto& tensor_map_buffer = torch::empty({num_sms * 4 * static_cast<int>(sizeof(CUtensorMap))},
                                                 a.first.options().dtype(torch::kByte));

    // Dispatch implementation
    const auto& arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer,
                                     cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }
}
static void bf16_gemm_nt(const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& d,
@@ -403,6 +463,43 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T
}
}
// cuBLASLt-backed GEMM `D = A @ B.T` with `a` of shape `[M, K]`, `b` of shape `[N, K]`
// and `d` of shape `[M, N]`. The optional `c` must share `d`'s dtype (its shape is not
// checked here). Returns without launching anything when `M == 0` or `N == 0`.
static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
// Shape must be `[M, K] @ [N, K].T`
// The detected majors are forwarded as-is, no particular major is required here
const auto& major_a = get_major_type_ab(a);
const auto& major_b = get_major_type_ab(b);
// Type and shape checks
const auto& [m , k ] = get_shape<2>(a);
const auto& [n , k_] = get_shape<2>(b);
const auto& [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
if (c.has_value())
DG_HOST_ASSERT(c.value().scalar_type() == d.scalar_type());
// Do nothing if the problem is empty
if (m == 0 or n == 0)
return;
cublaslt_gemm(a, b, c, d, m, n, k, major_a, major_b);
}
// NN variant of the cuBLASLt GEMM: `b` is handed over as its `(0, 1)` transpose so
// that the NT entry point (`[M, K] @ [N, K].T`) can be reused.
static void cublaslt_gemm_nn(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    const auto b_transposed = b.transpose(0, 1);
    cublaslt_gemm_nt(a, b_transposed, d, c);
}
// TN variant of the cuBLASLt GEMM: both operands are handed over as their `(0, 1)`
// transposes so that the NT entry point (`[M, K] @ [N, K].T`) can be reused.
static void cublaslt_gemm_tn(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    const auto a_transposed = a.transpose(0, 1);
    const auto b_transposed = b.transpose(0, 1);
    cublaslt_gemm_nt(a_transposed, b_transposed, d, c);
}
// TT variant of the cuBLASLt GEMM: only `a` is handed over as its `(0, 1)` transpose,
// `b` is forwarded untouched to the NT entry point (`[M, K] @ [N, K].T`).
static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    const auto a_transposed = a.transpose(0, 1);
    cublaslt_gemm_nt(a_transposed, b, d, c);
}
static void register_apis(pybind11::module_& m) {
// FP8 GEMMs
m.def("fp8_gemm_nt", &fp8_gemm_nt,
@@ -442,6 +539,11 @@ static void register_apis(pybind11::module_& m) {
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
py::arg("recipe") = std::make_tuple(1, 1, 128),
py::arg("compiled_dims") = "mn");
m.def("k_grouped_fp8_gemm_nt_contiguous", &k_grouped_fp8_gemm_nt_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
py::arg("recipe") = std::make_tuple(1, 1, 128),
py::arg("compiled_dims") = "mn");
// BF16 GEMMs
m.def("bf16_gemm_nt", &bf16_gemm_nt,
@@ -466,6 +568,16 @@ static void register_apis(pybind11::module_& m) {
m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"), py::arg("compiled_dims") = "nk");
// cuBLASLt GEMMs
m.def("cublaslt_gemm_nt", &cublaslt_gemm_nt,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
m.def("cublaslt_gemm_nn", &cublaslt_gemm_nn,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
m.def("cublaslt_gemm_tn", &cublaslt_gemm_tn,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
m.def("cublaslt_gemm_tt", &cublaslt_gemm_tt,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
}
} // namespace deep_gemm::gemm

View File

@@ -56,14 +56,14 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
// FP32 on SM90
if (sf.scalar_type() == torch::kFloat and arch_major == 9)
DG_HOST_UNREACHABLE("Unimplemented");
return get_mn_major_tma_aligned_tensor(sf);
// FP32 on SM100
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks);
// INT on SM100
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
if (sf.scalar_type() == torch::kInt and arch_major == 10)
DG_HOST_UNREACHABLE("Unimplemented");
DG_HOST_UNREACHABLE("Unknown cases");

View File

@@ -1,9 +1,16 @@
// GEMM kernels
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
// Einsum kernels
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
// Layout kernels
#include <deep_gemm/impls/smxx_layout.cuh>
using namespace deep_gemm;

View File

@@ -1,5 +1,6 @@
#pragma once
#include <cublasLt.h>
#include <ATen/cuda/CUDAContext.h>
#include "../utils/exception.hpp"
@@ -11,8 +12,28 @@ class DeviceRuntime {
int num_sms = 0, tc_util = 0;
std::shared_ptr<cudaDeviceProp> cached_prop;
// cuBLASLt utils
static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
cublasLtHandle_t cublaslt_handle{};
std::shared_ptr<torch::Tensor> cublaslt_workspace;
public:
explicit DeviceRuntime() = default;
// Allocates the cuBLASLt workspace (a `kCublasLtWorkspaceSize`-byte tensor on the CUDA
// device) and creates the cuBLASLt handle owned by this runtime.
explicit DeviceRuntime() {
cublaslt_workspace = std::make_shared<torch::Tensor>(torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)));
DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle));
}
// Destroys the cuBLASLt handle created by the constructor.
// NOTE(review): declared `noexcept(false)`, presumably because `DG_CUBLASLT_CHECK`
// throws on failure - confirm against the macro definition.
~DeviceRuntime() noexcept(false) {
DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle));
}
// Returns the cuBLASLt handle stored in this runtime; the runtime keeps ownership
// (the handle is destroyed in the destructor).
cublasLtHandle_t get_cublaslt_handle() const {
return cublaslt_handle;
}
// Returns, by value, the workspace tensor allocated in the constructor.
torch::Tensor get_cublaslt_workspace() const {
return *cublaslt_workspace;
}
std::shared_ptr<cudaDeviceProp> get_prop() {
if (cached_prop == nullptr)

View File

@@ -1,5 +1,7 @@
#pragma once
#include <deep_gemm/common/types.hpp>
#include "../../utils/math.hpp"
#include "../../utils/layout.hpp"
@@ -80,18 +82,19 @@ static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
return divisible and num_sms % num_multicast == 0;
}
static int get_swizzle_mode(const int& block_size, const int& elem_size) {
template <typename size_type_t>
static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) {
// `> 0` means interleaving
// 16B actually means non-swizzling (but interleaving)
for (const int& mode: {128, 64, 32, 16}) {
if ((block_size * elem_size) % mode == 0)
if ((block_size * static_cast<int>(elem_size)) % mode == 0)
return mode;
}
DG_HOST_UNREACHABLE("Unreachable");
}
template <typename ArchSpec>
static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const KernelType& kernel_type,
const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
@@ -104,7 +107,7 @@ static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n);
const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size);
const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size);
const int& swizzle_cd_mode = get_swizzle_mode(block_n, cd_elem_size);
const int& swizzle_cd_mode = ArchSpec::enable_cd_swizzle(cd_dtype) ? get_swizzle_mode(block_n, cd_elem_size) : 0;
// Different archs have different epilogue pipelines
const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype);
@@ -121,9 +124,11 @@ static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
// M-barriers and tensor memory pointers
const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages);
const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size();
const int& smem_tensor_map = ArchSpec::get_tensormap_smem_size(gemm_type);
// Sum them up
int smem_size = 0;
smem_size += smem_tensor_map;
smem_size += smem_cd;
smem_size += num_stages * smem_a_per_stage;
smem_size += num_stages * smem_b_per_stage;
@@ -151,15 +156,12 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
// Select M/N block sizes
// TODO: support `% 16 == 8` block size on SM90
auto block_ms = std::vector{64, 128, 256};
if (gemm_type == GemmType::MGroupedContiguous)
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
block_ms = std::vector{64, 128};
std::vector<int> block_ns;
for (int i = 16; i <= 256; i += 16)
block_ns.push_back(i);
const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype);
// K block size is selected in a fixed manner
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
@@ -214,9 +216,9 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0);
// Decide the number of TMA multicasts and whether broadcast on A
MulticastConfig best_multicast_config = {1, true};
MulticastConfig best_multicast_config = {1, false};
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
gemm_type, m, n, best_block_m, best_block_n, num_sms);
gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms);
const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
bool order[2] = {false, true};
if (best_block_m > best_block_n)
@@ -232,11 +234,11 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
constexpr int smem_capacity = ArchSpec::smem_capacity;
int best_num_stages = 0;
SharedMemoryConfig best_smem_config;
for (int num_stages = std::min(12, ceil_div(k, block_k)); num_stages > 0; -- num_stages) {
for (int num_stages = 12; num_stages > 0; -- num_stages) {
if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
continue;
best_smem_config = get_smem_config<ArchSpec>(kernel_type,
best_smem_config = get_smem_config<ArchSpec>(gemm_type, kernel_type,
m, n, k,
best_block_m, best_block_n, block_k,
major_a, major_b,

View File

@@ -12,6 +12,15 @@ namespace deep_gemm {
struct SM100ArchSpec {
static constexpr int smem_capacity = 232448;
// Lists the block N sizes to be considered: 16, then every multiple of 32 up to 256.
// `cd_dtype` does not influence the result.
static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
    // 16 is for better SM usage
    std::vector<int> block_ns{16};
    // Stride 32 is due to low-performance swizzle-16/32B
    for (int multiple = 1; multiple * 32 <= 256; ++ multiple)
        block_ns.push_back(multiple * 32);
    return block_ns;
}
// Returns the M extent of a load block: `block_m` divided by the number of multicasts
// when multicasting on A, `block_m` itself otherwise.
static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) {
    if (config.is_multicast_on_a)
        return block_m / config.num_multicast;
    return block_m;
}
@@ -29,6 +38,10 @@ struct SM100ArchSpec {
return block_n;
}
// Whether a swizzle mode is selected for the C/D tile; always enabled here,
// `cd_dtype` is ignored.
static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
return true;
}
static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) {
constexpr int num_utccp_aligned_elems = 128;
@@ -86,7 +99,7 @@ struct SM100ArchSpec {
return false;
}
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
const int& m, const int& n, const int& block_m, const int& block_n,
const int& num_sms) {
// TODO: support other layouts
@@ -138,12 +151,17 @@ struct SM100ArchSpec {
// TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers
// NOTES: 1D2D kernel will not use the with-SF full barriers
// NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
return num_stages * 8 * 3 + 2 * 8 * 2;
// NOTES: the last barrier is for tensor core utilization control
return num_stages * 8 * 3 + 2 * 8 * 2 + 8;
}
// Size reserved for the tensor memory pointer when summing up the shared memory
// usage (presumably in bytes, like the other shared memory terms - confirm).
static int get_tmem_ptr_smem_size() {
return 4;
}
// Shared memory reserved for tensor maps; none for any GEMM type here.
static int get_tensormap_smem_size(const GemmType& gemm_type) {
return 0;
}
};
} // namespace deep_gemm

View File

@@ -11,6 +11,15 @@ namespace deep_gemm {
struct SM90ArchSpec {
static constexpr int smem_capacity = 232448;
// Lists the block N sizes to be considered: a stride-16 sequence up to 256 that starts
// at 8 for FP32 outputs and at 16 for every other output type.
static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
    // Avoid bank conflicts for FP32 output
    int first_block_n = 16;
    if (cd_dtype == torch::kFloat)
        first_block_n = 8;

    std::vector<int> block_ns;
    for (int block_n = first_block_n; block_n <= 256; block_n += 16)
        block_ns.push_back(block_n);
    return block_ns;
}
// Returns the M extent of a load block; `multicast_config` is ignored here and
// the full `block_m` is returned.
static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) {
return block_m;
}
@@ -19,26 +28,35 @@ struct SM90ArchSpec {
return block_n;
}
static int get_cd_store_block_m(const int& block_m) {
return block_m;
static int get_cd_store_block_m(const int& block_m, const bool& single_warpgroup_sync = false) {
constexpr int wgmma_m = 64;
return single_warpgroup_sync ? wgmma_m : block_m;
}
// Returns the N extent of a C/D store block, which is the full `block_n`.
static int get_cd_store_block_n(const int& block_n) {
return block_n;
}
// Whether a swizzle mode is selected for the C/D tile: disabled for FP32 outputs,
// enabled for every other output type.
static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
    if (cd_dtype == torch::kFloat)
        return false;
    return true;
}
static bool is_block_size_legal(const KernelType& kernel_type,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const int& block_m, const int& block_n, const int& block_k) {
// FP32 output does not support `block_m == 256`
// SM90 FP32 output does not support `block_m == 256`
if (cd_dtype == at::kFloat and block_m == 256)
return false;
// TODO: more general block N selection
// Must be some fixed block N selections
if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 and block_n != 152))
return false;
// Avoid large C/D shared memory for FP32 output
// Ensure `num_stages >= 4` (for 1D1D Kernel), `num_stages >= 3` (for No SF kernel)
if (block_n > 128 and cd_dtype == torch::kFloat) {
if (kernel_type == KernelType::Kernel1D1D and block_n > 152)
return false;
if (kernel_type == KernelType::KernelNoSF and block_n > 200)
return false;
}
// Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
// Or too many register spills
@@ -66,9 +84,13 @@ struct SM90ArchSpec {
return true;
}
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
const int& m, const int& n, const int& block_m, const int& block_n,
const int& num_sms) {
// Disable multicast when the number of k-groups is large (a heuristic)
if (gemm_type == GemmType::KGroupedContiguous and num_groups > 4)
return {false, false};
return {
is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
// For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
@@ -96,9 +118,10 @@ struct SM90ArchSpec {
int smem_sfa_per_stage = block_m * static_cast<int>(sizeof(float));
int smem_sfb_per_stage = 0;
// TODO: figure out here
if (kernel_type == KernelType::Kernel1D1D)
smem_sfb_per_stage = align(block_n * 4, block_k);
if (kernel_type == KernelType::Kernel1D1D) {
// NOTES: `128` is for 2D TMA alignment requirement
smem_sfb_per_stage = align(block_n * 4, 128);
}
return {smem_sfa_per_stage, smem_sfb_per_stage};
}
@@ -109,13 +132,16 @@ struct SM90ArchSpec {
}
static int get_barrier_smem_size(const int& num_stages) {
// For 1D1D kernels, there is an extra barrier for accumulation
return (num_stages + 1) * 8 * 2;
return num_stages * 8 * 2;
}
// No shared memory is reserved for a tensor memory pointer in this spec.
static int get_tmem_ptr_smem_size() {
return 0;
}
// Shared memory reserved for tensor maps: room for four `CUtensorMap` objects for
// K-grouped contiguous GEMMs, nothing for every other GEMM type.
static int get_tensormap_smem_size(const GemmType& gemm_type) {
    if (gemm_type != GemmType::KGroupedContiguous)
        return 0;
    return 4 * static_cast<int>(sizeof(CUtensorMap));
}
};
} // namespace deep_gemm

View File

@@ -0,0 +1,12 @@
#pragma once
#include <optional>
#include <string>
namespace deep_gemm {
// Resolves the epilogue type name: the caller's choice when one is given,
// "EpilogueIdentity" otherwise.
static std::string get_default_epilogue_type(const std::optional<std::string>& epilogue_type) {
    if (not epilogue_type.has_value())
        return "EpilogueIdentity";
    return *epilogue_type;
}
} // namespace deep_gemm

View File

@@ -4,6 +4,8 @@
#include <torch/python.h>
#include "../../utils/math.hpp"
#include "../heuristics/sm90.hpp"
#include "../../utils/system.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
@@ -51,7 +53,11 @@ static std::string to_string(const at::ScalarType& dtype) {
}
}
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype) {
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype,
const bool& allow_tf32) {
if (allow_tf32 and dtype == torch::kFloat)
return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32;
switch (dtype) {
case torch::kInt: return CU_TENSOR_MAP_DATA_TYPE_INT32;
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
@@ -61,9 +67,14 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
}
}
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode) {
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) {
if (base != 0) {
DG_HOST_ASSERT(base == 32 and mode == 128);
return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
}
switch (mode) {
case 0: return CU_TENSOR_MAP_SWIZZLE_NONE;
case 0:
case 16: return CU_TENSOR_MAP_SWIZZLE_NONE;
case 32: return CU_TENSOR_MAP_SWIZZLE_32B;
case 64: return CU_TENSOR_MAP_SWIZZLE_64B;
@@ -76,7 +87,8 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
int gmem_inner_dim, int gmem_outer_dim,
int smem_inner_dim, int smem_outer_dim,
const int& gmem_outer_stride,
const int& swizzle_mode) {
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
const auto& elem_size = static_cast<int>(t.element_size());
if (swizzle_mode != 0)
smem_inner_dim = swizzle_mode / elem_size;
@@ -87,14 +99,42 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
const cuuint64_t gmem_strides[1] = {static_cast<cuuint64_t>(gmem_outer_stride * elem_size), };
const cuuint32_t elem_strides[2] = {1, 1};
if (get_env<int>("DG_JIT_DEBUG")) {
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d, elem size: %d\n",
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d\n",
gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim,
gmem_outer_stride, swizzle_mode, elem_size);
gmem_outer_stride, swizzle_mode, swizzle_base, elem_size);
}
DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type()),
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode),
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
return tensor_map;
}
// Encodes a 3D tiled TMA descriptor for tensor `t`.
//
// Parameters:
//   gmem_dim_{0,1,2}  - global memory extents, passed to the driver in this order
//   smem_dim_{0,1,2}  - box (shared memory tile) extents; with a non-zero swizzle mode,
//                       `smem_dim_0` must equal `swizzle_mode / element size`
//   gmem_stride_{0,1} - global strides in elements (converted to bytes here)
//   swizzle_mode      - swizzle mode, translated by `mode_into_tensor_map_swizzle`
//   swizzle_base      - swizzle base, forwarded together with the mode
//   allow_tf32        - forwarded to the data type translation
//
// Returns the encoded descriptor; driver failures go through `DG_CUDA_DRIVER_CHECK`.
static CUtensorMap make_tma_3d_desc(const torch::Tensor& t,
                                    const int& gmem_dim_0, const int& gmem_dim_1, const int& gmem_dim_2,
                                    const int& smem_dim_0, const int& smem_dim_1, const int& smem_dim_2,
                                    const int& gmem_stride_0, const int& gmem_stride_1,
                                    const int& swizzle_mode, const int& swizzle_base = 0,
                                    const bool& allow_tf32 = false) {
    const auto& elem_size = static_cast<int>(t.element_size());
    if (swizzle_mode != 0)
        DG_HOST_ASSERT(smem_dim_0 == swizzle_mode / elem_size);

    CUtensorMap tensor_map;
    const cuuint64_t gmem_dims[3] = {static_cast<cuuint64_t>(gmem_dim_0), static_cast<cuuint64_t>(gmem_dim_1), static_cast<cuuint64_t>(gmem_dim_2),};
    const cuuint32_t smem_dims[3] = {static_cast<cuuint32_t>(smem_dim_0), static_cast<cuuint32_t>(smem_dim_1), static_cast<cuuint32_t>(smem_dim_2)};
    // NOTES: widen before multiplying, the byte stride may not fit into a 32-bit `int`
    const cuuint64_t gmem_strides[2] = {static_cast<cuuint64_t>(gmem_stride_0) * static_cast<cuuint64_t>(elem_size),
                                        static_cast<cuuint64_t>(gmem_stride_1) * static_cast<cuuint64_t>(elem_size)};
    const cuuint32_t elem_strides[3] = {1, 1, 1};
    if (get_env<int>("DG_JIT_DEBUG")) {
        // Also report the swizzle base, like the 2D descriptor does
        printf("Making 3D TMA desc: global memory: %d %d %d, shared memory: %d %d %d, outer stride: %d %d, swizzle: %d (base: %d), elem size: %d\n",
               gmem_dim_0, gmem_dim_1, gmem_dim_2, smem_dim_0, smem_dim_1, smem_dim_2,
               gmem_stride_0, gmem_stride_1, swizzle_mode, swizzle_base, elem_size);
    }
    DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
        &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
        3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
        CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
        CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
    return tensor_map;
}
@@ -105,7 +145,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
const int& block_m, const int& block_k,
const int& outer_stride,
const int& num_groups,
const int& swizzle_mode) {
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
if (num_groups > 1)
DG_HOST_ASSERT(major == cute::UMMA::Major::K);
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
@@ -114,7 +155,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
gmem_inner_dim, gmem_outer_dim,
smem_inner_dim, smem_outer_dim,
outer_stride,
swizzle_mode);
swizzle_mode, swizzle_base,
allow_tf32);
}
static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
@@ -123,7 +165,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
const int& block_n, const int& block_k,
const int& outer_stride,
const int& num_groups,
const int& swizzle_mode) {
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n);
const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n);
@@ -132,7 +175,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
gmem_inner_dim, gmem_outer_dim * num_groups,
smem_inner_dim, smem_outer_dim,
outer_stride,
swizzle_mode);
swizzle_mode, swizzle_base,
allow_tf32);
}
static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
@@ -140,15 +184,16 @@ static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
const int& block_m, const int& block_n,
const int& outer_stride,
const int& num_groups,
const int& swizzle_mode) {
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
// Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode`
// bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
return make_tma_2d_desc(t,
shape_n, shape_m * num_groups,
block_n, block_m,
outer_stride,
swizzle_mode);
swizzle_mode, swizzle_base,
allow_tf32);
}
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
@@ -156,7 +201,8 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
int shape_mn, int shape_k,
const int& block_mn, const int& block_k,
const int& num_groups,
const int& swizzle_mode) {
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
DG_HOST_ASSERT(major == cute::UMMA::Major::MN);
// TODO: maybe swizzle SF as well
@@ -167,7 +213,8 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups,
block_mn, 1,
shape_mn,
swizzle_mode);
swizzle_mode, swizzle_base,
allow_tf32);
}
} // namespace deep_gemm

View File

@@ -42,7 +42,7 @@ static void __instantiate_kernel() {{
{}, {}, {},
{},
{}, {}, {},
{}, {},
{},
{}, {},
{}, {},
{},
@@ -56,7 +56,7 @@ static void __instantiate_kernel() {{
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.num_groups,
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.num_stages,
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms,
@@ -80,8 +80,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
// TODO: test other Ks
DG_HOST_ASSERT(k % 64 == 0);
const auto& aligned_k = align(k, 64);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Normal, KernelType::KernelNoSF,
m, n, k, 1, major_a, major_b,
@@ -122,7 +121,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
// Launch
const SM100BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = k,
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_config = config,

View File

@@ -0,0 +1,137 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT launch descriptor for the `sm100_bmn_bnk_mn_gemm_impl` device kernel.
// `generate_impl` emits the source that instantiates the kernel template and
// `launch_impl` forwards the runtime arguments to `launch_kernel`.
class SM100BmkBnkMnRuntime final: public LaunchRuntime<SM100BmkBnkMnRuntime> {
public:
// Everything needed to generate and launch one kernel instance.
struct Args {
// Problem sizes: `m`, `n`, `k` become template arguments, `s` is passed at launch
int s, m, n, k;
// Tile sizes (template arguments)
int block_m, block_n, block_k;
// Split factor (template argument)
int split_factor;
// Swizzle modes for the A/B and D tensor maps (template arguments)
int swizzle_ab_mode, swizzle_cd_mode;
// Number of pipeline stages (template argument)
int num_stages;
// Number of threads (template argument)
int num_threads;
// Launch configuration; not read in this class
// NOTE(review): presumably consumed by the `LaunchRuntime` base — confirm
LaunchArgs launch_args;
// Tensor maps passed to the kernel at launch
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
};
// Returns source code that takes the address of the kernel template instantiated
// with the values in `args`.
// NOTES: placeholders are positional, so the argument order below must match the
// template parameter order of `sm100_bmn_bnk_mn_gemm_impl`
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_bmn_bnk_mn_gemm_impl<
{}, {}, {},
{}, {}, {},
{},
{}, {},
{}, {}
>);
}};
)",
args.m, args.n, args.k,
args.block_m, args.block_n, args.block_k,
args.split_factor,
args.swizzle_ab_mode, args.swizzle_cd_mode,
args.num_stages, args.num_threads);
}
// Launches `kernel` with `s` and the three tensor maps; the result of
// `launch_kernel` is checked by `DG_CUDA_UNIFIED_CHECK`.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.s, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d));
}
};
// Host-side entry for the SM100 `bmk, bnk -> mn` kernel: picks the split factor and
// the number of pipeline stages, builds the TMA descriptors, then JIT-builds and
// launches `SM100BmkBnkMnRuntime`.
//
// `a` is described to TMA as a 2D tensor with inner dim `k` and outer dim `s * m`
// (outer stride `k`), `b` as inner `k` and outer `s * n` (outer stride `k`), and
// `d` as inner `n` and outer `m` (outer stride `n`).
//
// Requirements (checked with `DG_HOST_ASSERT`):
//  - `s`, `m`, `n`, `k` are all positive
//  - `k` is a multiple of 64 (`block_k`), `m` and `n` are multiples of 64
//  - `s * max(m, n)` fits into `int`
static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a,
                                  const torch::Tensor &b,
                                  const torch::Tensor &d,
                                  const int &s, const int &m, const int &n, const int &k) {
    constexpr int block_m = 128;
    constexpr int block_n = 128;
    constexpr int block_k = 64;
    constexpr int num_threads = 128;

    // Empty problems are rejected up front: with `m == 0` or `n == 0` the block count is
    // zero and `num_sms / num_mn_blocks` divides by zero; with `s == 0` or `k == 0` the
    // split factor is zero and the grid size computation divides by zero
    DG_HOST_ASSERT(s > 0 and m > 0 and n > 0 and k > 0);
    DG_HOST_ASSERT(k % block_k == 0);
    DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0);
    DG_HOST_ASSERT(static_cast<int64_t>(s) * static_cast<int64_t>(std::max(m, n)) <= std::numeric_limits<int>::max());

    const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast<int>(a.element_size()));
    const int swizzle_cd_mode = get_swizzle_mode(block_n, static_cast<int>(d.element_size()));

    // Get best config
    const int num_sms = device_runtime->get_num_sms();
    const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);
    const int num_sk_blocks = s * (k / block_k);
    const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1));

    // Shared memory terms that do not depend on the number of stages
    const int smem_cd = block_m * swizzle_cd_mode * 2;
    const int smem_a_per_stage = static_cast<int>(block_m * block_k * sizeof(cutlass::bfloat16_t));
    const int smem_b_per_stage = static_cast<int>(block_n * block_k * sizeof(cutlass::bfloat16_t));
    const int smem_tmem_ptr = SM100ArchSpec::get_tmem_ptr_smem_size();

    // Select best number of stages
    // NOTES: we select 4 as start, as it is tested to be faster than values > 4
    int num_stages = 4, smem_size = 0;
    while (true) {
        const int smem_barrier = SM100ArchSpec::get_barrier_smem_size(num_stages);
        smem_size = smem_cd
                  + (smem_a_per_stage + smem_b_per_stage) * num_stages
                  + smem_barrier
                  + smem_tmem_ptr;
        if (smem_size <= SM100ArchSpec::smem_capacity)
            break;
        -- num_stages;
    }
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    if (get_env("DG_JIT_DEBUG", 0)) {
        printf("S: %d, M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, split-K factor: %d, "
               "stages: %d, shared memory: %d, swizzle AB: %d, swizzle CD: %d\n",
               s, m, n, k, block_m, block_n, block_k, split_factor,
               num_stages, smem_size, swizzle_ab_mode, swizzle_cd_mode);
    }

    const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
    const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);
    const auto& tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode);

    // One thread block per (MN tile, split chunk) pair
    const SM100BmkBnkMnRuntime::Args& args = {
        .s = s, .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .split_factor = split_factor,
        .swizzle_ab_mode = swizzle_ab_mode,
        .swizzle_cd_mode = swizzle_cd_mode,
        .num_stages = num_stages,
        .num_threads = num_threads,
        .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_threads, smem_size),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d
    };
    const auto& code = SM100BmkBnkMnRuntime::generate(args);
    const auto& runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code);
    SM100BmkBnkMnRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -9,6 +9,8 @@
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "epilogue.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
@@ -18,6 +20,7 @@ public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
const std::optional<std::string>& epilogue_type;
GemmConfig gemm_config;
LaunchArgs launch_args;
@@ -44,11 +47,12 @@ static void __instantiate_kernel() {{
{}, {}, {},
{},
{}, {}, {},
{}, {},
{},
{}, {},
{}, {},
{},
{}, {}, {}
{}, {}, {},
{}
>);
}};
)",
@@ -57,11 +61,12 @@ static void __instantiate_kernel() {{
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.num_groups,
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.num_stages,
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms,
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype));
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype),
get_default_epilogue_type(args.epilogue_type));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -80,7 +85,8 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const std::string& compiled_dims,
const std::optional<std::string>& epilogue_type = std::nullopt) {
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Normal, KernelType::Kernel1D1D,
@@ -99,7 +105,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
@@ -129,6 +135,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.epilogue_type = epilogue_type,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -186,6 +193,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -244,6 +252,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -324,6 +333,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
.m = m, .n = n, .k = sum_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,

View File

@@ -18,6 +18,7 @@ public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
const std::optional<std::string>& epilogue_type;
GemmConfig gemm_config;
LaunchArgs launch_args;
@@ -46,7 +47,8 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{},
{}, {}
{}, {},
{}
>);
}};
)",
@@ -59,7 +61,8 @@ static void __instantiate_kernel() {{
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms,
to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype));
to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype),
get_default_epilogue_type(args.epilogue_type));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -78,7 +81,8 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const std::string& compiled_dims,
const std::optional<std::string>& epilogue_type = std::nullopt) {
DG_HOST_ASSERT(not c.has_value());
const auto& aligned_k = align(k, 128);
@@ -98,7 +102,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
@@ -111,6 +115,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.epilogue_type = epilogue_type,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -164,6 +169,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -218,6 +224,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const t
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,

View File

@@ -41,7 +41,7 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{}, {},
{}, {}
{}, {}, {}
>);
}};
)",
@@ -53,7 +53,8 @@ static void __instantiate_kernel() {{
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
to_string(args.gemm_config.cd_dtype));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -73,10 +74,10 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(not c.has_value());
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(k % 64 == 0);
const auto& aligned_k = align(k, 64);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Normal, KernelType::KernelNoSF,
m, n, k, 1, major_a, major_b,
@@ -102,7 +103,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
// Launch
const SM90BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = k,
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_config = config,

View File

@@ -0,0 +1,131 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT launch descriptor for the `sm90_bmn_bnk_mn_gemm_impl` device kernel.
// `generate_impl` emits the source that instantiates the kernel template and
// `launch_impl` forwards the runtime arguments to `launch_kernel`.
class SM90BmkBnkMnRuntime final: public LaunchRuntime<SM90BmkBnkMnRuntime> {
public:
// Everything needed to generate and launch one kernel instance.
struct Args {
// Problem sizes: `m`, `n`, `k` become template arguments, `s` is passed at launch
int s, m, n, k;
// Tile sizes (template arguments)
int block_m, block_n, block_k;
// Split factor (template argument)
int split_factor;
// Number of pipeline stages (template argument)
int num_stages;
// Thread counts of the TMA and math groups (template arguments)
int num_tma_threads, num_math_threads;
// Launch configuration; not read in this class
// NOTE(review): presumably consumed by the `LaunchRuntime` base — confirm
LaunchArgs launch_args;
// Tensor maps for A and B, passed to the kernel at launch
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
// Raw FP32 output pointer, passed to the kernel at launch (no tensor map for D)
float* d;
};
// Returns source code that takes the address of the kernel template instantiated
// with the values in `args`.
// NOTES: placeholders are positional, so the argument order below must match the
// template parameter order of `sm90_bmn_bnk_mn_gemm_impl`
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_bmn_bnk_mn_gemm_impl<
{}, {}, {},
{}, {}, {},
{},
{},
{}, {}
>);
}};
)",
args.m, args.n, args.k,
args.block_m, args.block_n, args.block_k,
args.split_factor,
args.num_stages,
args.num_tma_threads, args.num_math_threads);
}
// Launches `kernel` with `s`, the A/B tensor maps and the output pointer; the result
// of `launch_kernel` is checked by `DG_CUDA_UNIFIED_CHECK`.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.s, args.tensor_map_a, args.tensor_map_b, args.d));
}
};
// Host-side entry for the SM90 `bmk, bnk -> mn` kernel: picks the split factor and
// the number of pipeline stages, builds the TMA descriptors for A and B, then
// JIT-builds and launches `SM90BmkBnkMnRuntime`.
//
// `a` is described to TMA as a 2D tensor with inner dim `k` and outer dim `s * m`
// (outer stride `k`), `b` as inner `k` and outer `s * n` (outer stride `k`).
// `d` is handed to the kernel as a raw `float*` obtained via `d.data_ptr<float>()`.
//
// Requirements (checked with `DG_HOST_ASSERT`):
//  - `s`, `m`, `n`, `k` are all positive
//  - `k` is a multiple of 64 (`block_k`), `m` and `n` are multiples of 64
//  - `s * max(m, n)` fits into `int`
//  - the A/B swizzle mode derived from `block_k` and the element size of `a` is 128
static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a,
                                 const torch::Tensor &b,
                                 const torch::Tensor &d,
                                 const int &s, const int &m, const int &n, const int &k) {
    constexpr int block_m = 128;
    constexpr int block_n = 128;
    constexpr int block_k = 64;
    constexpr int num_tma_threads = 128;
    constexpr int num_math_threads = 256;

    // Empty problems are rejected up front: with `m == 0` or `n == 0` the block count is
    // zero and `num_sms / num_mn_blocks` divides by zero; with `s == 0` or `k == 0` the
    // split factor is zero and the grid size computation divides by zero
    DG_HOST_ASSERT(s > 0 and m > 0 and n > 0 and k > 0);
    DG_HOST_ASSERT(k % block_k == 0);
    DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0);
    DG_HOST_ASSERT(static_cast<int64_t>(s) * static_cast<int64_t>(std::max(m, n)) <= std::numeric_limits<int>::max());

    const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast<int>(a.element_size()));
    DG_HOST_ASSERT(swizzle_ab_mode == 128);

    // Get best config
    const int num_sms = device_runtime->get_num_sms();
    const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);
    const int num_sk_blocks = s * (k / block_k);
    const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1));

    // Per-stage shared memory usage does not depend on the number of stages
    const int smem_a_per_stage = static_cast<int>(block_m * block_k * sizeof(cutlass::bfloat16_t));
    const int smem_b_per_stage = static_cast<int>(block_n * block_k * sizeof(cutlass::bfloat16_t));

    // Select best number of stages: start from 4 and shrink until it fits
    int num_stages = 4, smem_size = 0;
    while (true) {
        const int smem_barrier = SM90ArchSpec::get_barrier_smem_size(num_stages);
        smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages + smem_barrier;
        if (smem_size <= SM90ArchSpec::smem_capacity)
            break;
        -- num_stages;
    }
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    if (get_env("DG_JIT_DEBUG", 0)) {
        printf("S: %d, M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, split-K factor: %d, "
               "stages: %d, shared memory: %d, swizzle AB: %d\n",
               s, m, n, k, block_m, block_n, block_k, split_factor,
               num_stages, smem_size, swizzle_ab_mode);
    }

    const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
    const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);

    // One thread block per (MN tile, split chunk) pair
    const SM90BmkBnkMnRuntime::Args& args = {
        .s = s, .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .split_factor = split_factor,
        .num_stages = num_stages,
        .num_tma_threads = num_tma_threads,
        .num_math_threads = num_math_threads,
        .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_tma_threads + num_math_threads, smem_size),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .d = d.data_ptr<float>()
    };
    const auto& code = SM90BmkBnkMnRuntime::generate(args);
    const auto& runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code);
    SM90BmkBnkMnRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,214 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT launch descriptor for the `sm90_fp8_gemm_1d1d_impl` device kernel.
// `generate_impl` emits the source that instantiates the kernel template and
// `launch_impl` forwards the runtime arguments to `launch_kernel`.
class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime<SM90FP8Gemm1D1DRuntime> {
public:
// Everything needed to generate and launch one kernel instance.
struct Args {
// Problem sizes: used for code generation (via `get_compiled_dim` for `m`, `n`, `k`)
// and `m`, `n`, `k` are also passed at launch
int m, n, k, num_groups;
// Passed to `get_compiled_dim` together with each of `m`, `n`, `k`
// NOTES: stored by reference, so the referenced string must outlive this struct
const std::string& compiled_dims;
GemmConfig gemm_config;
// Launch configuration; not read in this class
// NOTE(review): presumably consumed by the `LaunchRuntime` base — confirm
LaunchArgs launch_args;
// Raw pointers passed to the kernel at launch; the normal GEMM entry in this file
// passes `nullptr` for all four, the K-grouped entry passes real device pointers
void *gmem_a_ptr;
void *gmem_b_ptr;
void *grouped_layout;
void *tensor_map_buffer;
// Tensor maps passed to the kernel at launch
CUtensorMap tensor_map_a_base;
CUtensorMap tensor_map_b_base;
CUtensorMap tensor_map_sfa;
CUtensorMap tensor_map_sfb;
CUtensorMap tensor_map_d;
};
// Returns source code that takes the address of the kernel template instantiated
// with the values in `args`.
// NOTES: placeholders are positional (15 in total), so the argument order below must
// match the template parameter order of `sm90_fp8_gemm_1d1d_impl`
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d1d_impl<
{}, {}, {},
{},
{}, {}, {},
{},
{}, {},
{}, {},
{},
{}, {}
>);
}};
)",
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.num_groups,
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.gemm_config.num_stages,
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
to_string(args.gemm_config.cd_dtype));
}
// Launches `kernel` with the raw pointers, the problem sizes and the five tensor maps;
// the result of `launch_kernel` is checked by `DG_CUDA_UNIFIED_CHECK`.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.gmem_a_ptr, args.gmem_b_ptr,
args.grouped_layout,
args.tensor_map_buffer,
args.m, args.n, args.k,
args.tensor_map_a_base, args.tensor_map_b_base,
args.tensor_map_sfa, args.tensor_map_sfb,
args.tensor_map_d));
}
};
// Host-side entry for the normal (non-grouped) SM90 FP8 1D1D GEMM.
// Selects a config through `get_best_config`, builds the TMA descriptors for A, B,
// both scaling-factor tensors and D, then JIT-builds and launches the kernel.
//
// Requirements (checked with `DG_HOST_ASSERT`):
//  - `c` is provided and `d` is FP32
//  - both A and B are K-major
//  - the selected A/B swizzle modes equal `block_k` (no TMA splits)
static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                               const torch::Tensor& b, const torch::Tensor& sfb,
                               const std::optional<torch::Tensor>& c,
                               const torch::Tensor& d,
                               const int& m, const int& n, const int& k,
                               const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                               const std::string& compiled_dims) {
    DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    const auto& cfg = get_best_config<SM90ArchSpec>(
        GemmType::Normal, KernelType::Kernel1D1D,
        m, n, k, 1, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // Requires no TMA splits
    DG_HOST_ASSERT(cfg.smem_config.swizzle_a_mode == cfg.block_k);
    DG_HOST_ASSERT(cfg.smem_config.swizzle_b_mode == cfg.block_k);

    // A and B: the outer stride passed here is `k`
    const auto& desc_a = make_tma_a_desc(
        major_a, a, m, k,
        SM90ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
        cfg.block_k, k, 1,
        cfg.smem_config.swizzle_a_mode);
    const auto& desc_b = make_tma_b_desc(
        major_b, b, n, k,
        SM90ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
        cfg.block_k, k, 1,
        cfg.smem_config.swizzle_b_mode);

    // Scaling factors: MN-major, no swizzling
    const auto& desc_sfa = make_tma_sf_desc(
        cute::UMMA::Major::MN, sfa, m, k,
        cfg.block_m, cfg.block_k, 1, 0);
    const auto& desc_sfb = make_tma_sf_desc(
        cute::UMMA::Major::MN, sfb, n, k,
        cfg.block_n, cfg.block_k, 1, 0);

    // Output: no swizzling, outer stride taken from the tensor itself
    const auto& desc_d = make_tma_cd_desc(
        d, m, n,
        SM90ArchSpec::get_cd_store_block_m(cfg.block_m, true),
        SM90ArchSpec::get_cd_store_block_n(cfg.block_n),
        static_cast<int>(d.stride(-2)), 1,
        0);

    // Launch: the raw pointers are all null for the non-grouped case
    const SM90FP8Gemm1D1DRuntime::Args& launch = {
        .m = m, .n = n, .k = k,
        .num_groups = 1,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                  cfg.smem_config.smem_size,
                                  cfg.multicast_config.num_multicast),
        .gmem_a_ptr = nullptr,
        .gmem_b_ptr = nullptr,
        .grouped_layout = nullptr,
        .tensor_map_buffer = nullptr,
        .tensor_map_a_base = desc_a,
        .tensor_map_b_base = desc_b,
        .tensor_map_sfa = desc_sfa,
        .tensor_map_sfb = desc_sfb,
        .tensor_map_d = desc_d,
    };
    const auto& source = SM90FP8Gemm1D1DRuntime::generate(launch);
    const auto& built = compiler->build("sm90_fp8_gemm_1d1d", source);
    SM90FP8Gemm1D1DRuntime::launch(built, launch);
}
// Host-side entry for the K-grouped SM90 FP8 1D1D GEMM.
// `ks` holds the K extent of every group on the host; `ks_tensor` and
// `tensor_map_buffer` are handed to the kernel as raw device pointers
// (`grouped_layout` and `tensor_map_buffer` respectively).
//
// Requirements (checked with `DG_HOST_ASSERT`):
//  - `c` is provided and `d` is FP32
//  - both A and B are K-major
//  - `ks` is not empty and every entry is a multiple of 128
//  - the selected A/B swizzle modes equal `block_k` (no TMA splits)
static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                         const torch::Tensor& b, const torch::Tensor& sfb,
                                         const std::optional<torch::Tensor>& c,
                                         const torch::Tensor& d,
                                         const int& m, const int& n,
                                         const std::vector<int>& ks, const torch::Tensor& ks_tensor,
                                         const torch::Tensor& tensor_map_buffer,
                                         const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                         const std::string& compiled_dims) {
    DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    // `std::max_element` returns `ks.end()` for an empty range, which must not be dereferenced
    DG_HOST_ASSERT(not ks.empty());

    // Get config using max K for better performance
    const auto& num_groups = static_cast<int>(ks.size());
    const auto& max_k = *std::max_element(ks.begin(), ks.end());
    const auto& config = get_best_config<SM90ArchSpec>(
        GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
        m, n, max_k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // Requires no TMA splits
    DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
    DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);

    // `first_k` is the first non-zero K and shapes the base A/B descriptors;
    // `sum_k` is the total K, `sum_sf_k` the total number of 128-wide K chunks
    int first_k = 0, sum_k = 0, sum_sf_k = 0;
    for (const int group_k: ks) {
        DG_HOST_ASSERT(group_k % 128 == 0);
        if (first_k == 0 and group_k != 0)
            first_k = group_k;
        sum_k += group_k;
        sum_sf_k += ceil_div(group_k, 128);
    }

    const auto& tensor_map_a_base = make_tma_a_desc(major_a, a, m, first_k,
                                                    SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
                                                    config.block_k, first_k, 1,
                                                    config.smem_config.swizzle_a_mode);
    const auto& tensor_map_b_base = make_tma_b_desc(major_b, b, n, first_k,
                                                    SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
                                                    config.block_k, first_k, 1,
                                                    config.smem_config.swizzle_b_mode);
    const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128,
                                                  config.block_m, config.block_k, 1, 0);
    const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128,
                                                  config.block_n, config.block_k, 1, 0);
    const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
                                                SM90ArchSpec::get_cd_store_block_m(config.block_m, true),
                                                SM90ArchSpec::get_cd_store_block_n(config.block_n),
                                                static_cast<int>(d.stride(-2)), num_groups,
                                                config.smem_config.swizzle_cd_mode);

    // Launch
    const SM90FP8Gemm1D1DRuntime::Args& args = {
        .m = m, .n = n, .k = sum_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  config.smem_config.smem_size,
                                  config.multicast_config.num_multicast),
        .gmem_a_ptr = a.data_ptr(),
        .gmem_b_ptr = b.data_ptr(),
        .grouped_layout = ks_tensor.data_ptr(),
        .tensor_map_buffer = tensor_map_buffer.data_ptr(),
        .tensor_map_a_base = tensor_map_a_base,
        .tensor_map_b_base = tensor_map_b_base,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_d = tensor_map_d,
    };
    const auto& code = SM90FP8Gemm1D1DRuntime::generate(args);
    const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code);
    SM90FP8Gemm1D1DRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -8,6 +8,8 @@
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../heuristics/sm90.hpp"
#include "epilogue.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
@@ -17,6 +19,7 @@ public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
const std::optional<std::string>& epilogue_type;
GemmConfig gemm_config;
LaunchArgs launch_args;
@@ -43,7 +46,7 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{}, {},
{}, {}
{}, {}, {}
>);
}};
)",
@@ -55,7 +58,8 @@ static void __instantiate_kernel() {{
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
get_default_epilogue_type(args.epilogue_type));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -74,7 +78,8 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const std::string& compiled_dims,
const std::optional<std::string>& epilogue_type = std::nullopt) {
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
@@ -98,7 +103,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
@@ -111,6 +116,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.epilogue_type = epilogue_type,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -170,6 +176,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
@@ -230,6 +237,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,

View File

@@ -0,0 +1,151 @@
#pragma once
#include <cublasLt.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
namespace deep_gemm {
// Creates a cuBLASLt matrix layout for a `rows x cols` matrix of `type` with leading
// dimension `ld`. When `batch_count` is given, `batch_offset` must be given as well,
// and both are set on the layout as the batch count and the strided batch offset.
// Every cuBLASLt call is checked with `DG_CUBLASLT_CHECK`.
// The returned handle is not destroyed here.
static auto get_cublaslt_layout(const cudaDataType& type, const int& rows, const int& cols, const int& ld,
                                const std::optional<int>& batch_count = std::nullopt,
                                const std::optional<int>& batch_offset = std::nullopt) {
    cublasLtMatrixLayout_t layout;
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&layout, type, rows, cols, ld));

    // Non-batched layouts need nothing more
    if (not batch_count.has_value())
        return layout;

    DG_HOST_ASSERT(batch_offset.has_value());
    const int& num_batches = batch_count.value();
    // The strided batch offset attribute is 64-bit, so widen before passing its address
    const int64_t stride_between_batches = batch_offset.value();
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(
        layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
        &num_batches, sizeof(num_batches)));
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(
        layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
        &stride_between_batches, sizeof(stride_between_batches)));
    return layout;
}
// Runs one cuBLASLt matmul, `D = alpha * (op(A) @ op(B)) + beta * C`, with `alpha = 1`,
// `beta = accumulate ? 1 : 0`, and `d` used as both C and D.
//
// Operand naming is swapped on purpose: cuBLASLt operand A reads from `b` and is described
// by `layout_a`, while operand B reads from `a` and is described by `layout_b`.
//
// Ownership: this function destroys `layout_a`, `layout_b` and `layout_d` before it returns.
// That now also holds when an error is thrown part-way through, so a failed call (for
// example when no algorithm is found) no longer leaks the layouts, the matmul descriptor
// or the preference object.
//
// Throws `DGException` (through `DG_CUBLASLT_CHECK`) when a cuBLASLt call does not return
// `CUBLAS_STATUS_SUCCESS`; the heuristic check goes through `DG_HOST_ASSERT`.
static void call_cublaslt_api(const cublasOperation_t& trans_a,
                              const cublasOperation_t& trans_b,
                              const cublasLtMatrixLayout_t& layout_a,
                              const cublasLtMatrixLayout_t& layout_b,
                              const cublasLtMatrixLayout_t& layout_d,
                              const torch::Tensor& a,
                              const torch::Tensor& b,
                              const torch::Tensor& d,
                              const bool& accumulate) {
    // Releases every cuBLASLt object owned by this call if an error unwinds the stack.
    // The success path disarms the guard and performs status-checked destruction instead.
    struct CleanupGuard {
        cublasLtMatmulDesc_t desc = nullptr;
        cublasLtMatmulPreference_t pref = nullptr;
        cublasLtMatrixLayout_t layout_a = nullptr;
        cublasLtMatrixLayout_t layout_b = nullptr;
        cublasLtMatrixLayout_t layout_d = nullptr;
        bool armed = true;

        ~CleanupGuard() {
            if (not armed)
                return;

            // Best effort only: a destructor must not throw, so statuses are ignored here
            if (pref != nullptr)
                cublasLtMatmulPreferenceDestroy(pref);
            if (layout_a != nullptr)
                cublasLtMatrixLayoutDestroy(layout_a);
            if (layout_b != nullptr)
                cublasLtMatrixLayoutDestroy(layout_b);
            if (layout_d != nullptr)
                cublasLtMatrixLayoutDestroy(layout_d);
            if (desc != nullptr)
                cublasLtMatmulDescDestroy(desc);
        }
    } guard;
    guard.layout_a = layout_a;
    guard.layout_b = layout_b;
    guard.layout_d = layout_d;

    const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
    const cudaDataType_t scale_type = CUDA_R_32F;
    const int math_sms = device_runtime->get_num_sms();
    const bool fp8_fast_accumulate = false;

    // Operation description
    cublasLtMatmulDesc_t desc;
    DG_CUBLASLT_CHECK(cublasLtMatmulDescCreate(&desc, compute_type, scale_type));
    guard.desc = desc;
    DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a)));
    DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b)));
    DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));
    DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms)));
    // The fast-accumulation flag is only written for FP8 inputs, and it is written as "off"
    if (a.scalar_type() == torch::kFloat8_e4m3fn)
        DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate)));

    // Get cuBLASLt handle, workspace, and stream
    const auto& handle = device_runtime->get_cublaslt_handle();
    const auto& workspace = device_runtime->get_cublaslt_workspace();
    const auto& workspace_bytes = workspace.nbytes();
    const auto& stream = at::cuda::getCurrentCUDAStream();

    // Algorithm selection: ask for a single candidate that fits into the workspace
    cublasLtMatmulPreference_t pref;
    cublasLtMatmulHeuristicResult_t heuristic;
    int num_heuristic_results = 0;
    const uint32_t reduction_scheme_mask = CUBLASLT_REDUCTION_SCHEME_NONE | CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE;
    DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceCreate(&pref));
    guard.pref = pref;
    DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                                           &workspace_bytes, sizeof(workspace_bytes)));
    DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
                                                           &reduction_scheme_mask, sizeof(reduction_scheme_mask)));
    DG_CUBLASLT_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, layout_a, layout_b, layout_d, layout_d,
                                                     pref, 1, &heuristic, &num_heuristic_results));
    DG_HOST_ASSERT(num_heuristic_results == 1 and "Unable to find any algorithm for the GEMM");

    // Call: D = alpha * (A @ B) + beta * C
    const float alpha = 1.0f;
    const float beta = accumulate ? 1.0f : 0.0f;
    DG_CUBLASLT_CHECK(cublasLtMatmul(handle,                                 // Light handle
                                     desc,                                   // Operation description
                                     &alpha,                                 // Alpha
                                     b.data_ptr(), layout_a,                 // A
                                     a.data_ptr(), layout_b,                 // B
                                     &beta,                                  // Beta
                                     d.data_ptr(), layout_d,                 // C
                                     d.data_ptr(), layout_d,                 // D
                                     &heuristic.algo,                        // Algorithm
                                     workspace.data_ptr(), workspace_bytes,  // Workspace
                                     stream));                               // Stream

    // Free memory: from here on destruction is status-checked, so the guard is disarmed
    // first to make sure nothing is destroyed twice
    guard.armed = false;
    DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceDestroy(pref));
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_a));
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_b));
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_d));
    DG_CUBLASLT_CHECK(cublasLtMatmulDescDestroy(desc));
}
// Dispatches a single (non-batched) GEMM to cuBLASLt, writing the result into `out`.
//
// `rhs` becomes cuBLASLt operand A and `lhs` becomes operand B; the result layout is
// described as `n x m` with leading dimension `out.stride(0)`.
// NOTE(review): this looks like the usual "row-major via swapped operands" trick, with
// `lhs` as [m, k], `rhs` as [n, k] and `out` as [m, n] — confirm against the callers.
//
// If `acc` is given, the GEMM accumulates on top of it: `acc` is first copied into `out`
// unless both already share the same storage, in which case sizes and strides must match.
static void cublaslt_gemm(const torch::Tensor& lhs, const torch::Tensor& rhs,
                          const std::optional<torch::Tensor>& acc,
                          const torch::Tensor& out,
                          const int& m, const int& n, const int& k,
                          const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major) {
    const bool lhs_is_k_major = (a_major == cute::UMMA::Major::K);
    const bool rhs_is_k_major = (b_major == cute::UMMA::Major::K);

    // Operand A is `rhs`, so its transpose flag follows `b_major` (and vice versa)
    const cublasOperation_t op_a = rhs_is_k_major ? CUBLAS_OP_T : CUBLAS_OP_N;
    const cublasOperation_t op_b = lhs_is_k_major ? CUBLAS_OP_N : CUBLAS_OP_T;

    // Duplicate the accumulator if necessary
    // TODO: remove this
    const bool has_accumulator = acc.has_value();
    if (has_accumulator) {
        const bool shares_storage = (acc->data_ptr() == out.data_ptr());
        if (not shares_storage) {
            out.copy_(acc.value());
        } else {
            DG_HOST_ASSERT(acc->sizes() == out.sizes() and acc->strides() == out.strides());
        }
    }

    // Element types, named after the tensor they belong to
    const auto rhs_dtype = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type());
    const auto lhs_dtype = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type());
    const auto out_dtype = at::cuda::ScalarTypeToCudaDataType(out.scalar_type());

    // Matrix layouts: the leading dimension is the stride of the non-contiguous axis
    const auto operand_a_layout = rhs_is_k_major
        ? get_cublaslt_layout(rhs_dtype, k, n, rhs.stride(0))
        : get_cublaslt_layout(rhs_dtype, n, k, rhs.stride(1));
    const auto operand_b_layout = lhs_is_k_major
        ? get_cublaslt_layout(lhs_dtype, k, m, lhs.stride(0))
        : get_cublaslt_layout(lhs_dtype, m, k, lhs.stride(1));
    const auto result_layout = get_cublaslt_layout(out_dtype, n, m, out.stride(0));

    call_cublaslt_api(op_a, op_b, operand_a_layout, operand_b_layout, result_layout,
                      lhs, rhs, out, has_accumulator);
}
// Strided-batched BF16 GEMM over `h` batches with GEMM sizes `m = d`, `n = b`, `k = r`.
// NOTE(review): going by the name, this is presumably the einsum `bhr,hdr->bhd` with
// `lhs` as [b, h, r], `rhs` as [h, d, r] and `out` as [b, h, d] — confirm with callers.
//
// The `rhs` batch offset is `rhs.stride(0)`, while `lhs` and `out` use their `stride(1)`.
// No accumulation is requested, so `out` is overwritten.
static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
                                 const int& b, const int& h, const int& r, const int& d) {
    const int gemm_m = d;
    const int gemm_n = b;
    const int gemm_k = r;

    // Operand A (`rhs`) is described as `k x m` and transposed; operand B (`lhs`) is used as-is
    const cublasOperation_t op_a = CUBLAS_OP_T;
    const cublasOperation_t op_b = CUBLAS_OP_N;

    // Matrix layouts
    const auto operand_a_layout = get_cublaslt_layout(CUDA_R_16BF, gemm_k, gemm_m, rhs.stride(1), h, rhs.stride(0));
    const auto operand_b_layout = get_cublaslt_layout(CUDA_R_16BF, gemm_k, gemm_n, lhs.stride(0), h, lhs.stride(1));
    const auto result_layout = get_cublaslt_layout(CUDA_R_16BF, gemm_m, gemm_n, out.stride(0), h, out.stride(1));

    call_cublaslt_api(op_a, op_b, operand_a_layout, operand_b_layout, result_layout, lhs, rhs, out, false);
}
// Strided-batched BF16 GEMM over `h` batches with GEMM sizes `m = r`, `n = b`, `k = d`.
// NOTE(review): going by the name, this is presumably the einsum `bhd,hdr->bhr` with
// `lhs` as [b, h, d], `rhs` as [h, d, r] and `out` as [b, h, r] — confirm with callers.
//
// The `rhs` batch offset is `rhs.stride(0)`, while `lhs` and `out` use their `stride(1)`.
// No accumulation is requested, so `out` is overwritten.
static void cublaslt_bhd_hdr_bhr(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
                                 const int& b, const int& h, const int& r, const int& d) {
    const int gemm_m = r;
    const int gemm_n = b;
    const int gemm_k = d;

    // Neither operand is transposed here: `rhs` is already described as `m x k`
    const cublasOperation_t op_a = CUBLAS_OP_N;
    const cublasOperation_t op_b = CUBLAS_OP_N;

    // Matrix layouts
    const auto operand_a_layout = get_cublaslt_layout(CUDA_R_16BF, gemm_m, gemm_k, rhs.stride(1), h, rhs.stride(0));
    const auto operand_b_layout = get_cublaslt_layout(CUDA_R_16BF, gemm_k, gemm_n, lhs.stride(0), h, lhs.stride(1));
    const auto result_layout = get_cublaslt_layout(CUDA_R_16BF, gemm_m, gemm_n, out.stride(0), h, out.stride(1));

    call_cublaslt_api(op_a, op_b, operand_a_layout, operand_b_layout, result_layout, lhs, rhs, out, false);
}
} // namespace deep_gemm

View File

@@ -1,6 +1,8 @@
#include <pybind11/pybind11.h>
#include <torch/python.h>
#include "apis/attention.hpp"
#include "apis/einsum.hpp"
#include "apis/gemm.hpp"
#include "apis/layout.hpp"
#include "apis/runtime.hpp"
@@ -13,6 +15,8 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "DeepGEMM C++ library";
deep_gemm::attention::register_apis(m);
deep_gemm::einsum::register_apis(m);
deep_gemm::gemm::register_apis(m);
deep_gemm::layout::register_apis(m);
deep_gemm::runtime::register_apis(m);

View File

@@ -1,5 +1,6 @@
#pragma once
#include <cublasLt.h>
#include <exception>
#include <string>
#include <sstream>
@@ -72,4 +73,16 @@ do { \
} while (0)
#endif
#ifndef DG_CUBLASLT_CHECK
// Evaluates `cmd` exactly once and throws a `DGException` tagged "cuBLASLt" unless the
// resulting status equals `CUBLAS_STATUS_SUCCESS`. The exception message carries the
// numeric status followed by its `cublasGetStatusString` text in parentheses, and the
// file/line reported are those of the macro's use site.
// Wrapped in `do { ... } while (0)` so that it behaves like a single statement.
#define DG_CUBLASLT_CHECK(cmd) \
do { \
    const auto& dg_cublaslt_status = (cmd); \
    if (dg_cublaslt_status != CUBLAS_STATUS_SUCCESS) { \
        std::ostringstream dg_cublaslt_message; \
        dg_cublaslt_message << static_cast<int>(dg_cublaslt_status) \
                            << " (" << cublasGetStatusString(dg_cublaslt_status) << ")"; \
        throw DGException("cuBLASLt", __FILE__, __LINE__, dg_cublaslt_message.str()); \
    } \
} while (0)
#endif
} // namespace deep_gemm

View File

@@ -1,11 +1,15 @@
#pragma once
#include <array>
#include <filesystem>
#include <functional>
#include <random>
#include <string>
#include <memory>
#include <unistd.h>
#include "exception.hpp"
#include "format.hpp"
namespace deep_gemm {
@@ -65,7 +69,10 @@ static std::filesystem::path make_dirs(const std::filesystem::path& path) {
// OK if existed
std::error_code capture;
const bool& created = std::filesystem::create_directories(path, capture);
DG_HOST_ASSERT(created or capture.value() == 0);
if (not (created or capture.value() == 0)) {
DG_HOST_UNREACHABLE(fmt::format("Failed to make directory: {}, created: {}, value: {}",
path.c_str(), created, capture.value()));
}
if (created and get_env<int>("DG_JIT_DEBUG"))
printf("Create directory: %s\n", path.c_str());
return path;