Files
DeepGEMM/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp
Chenggang Zhao 7f2a703ed5 [Public release 26/04] Introducing Mega MoE, FP4 Indexer and other features/fixes (#304)
* Merge with private repo

* Update README

* Update README

* Update README

* Add PyTorch requirements

* Fix sync scopes for MQA logits (#256)

* Update README
2026-04-17 09:45:14 +08:00

433 lines
22 KiB
C++

#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
class SM90BF16GemmRuntime final: public LaunchRuntime<SM90BF16GemmRuntime> {
public:
struct Args {
GemmDesc gemm_desc;
GemmConfig gemm_config;
LaunchArgs launch_args;
void *grouped_layout;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_cd;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_bf16_gemm_impl<
{}, {},
{}, {}, {},
{},
{}, {}, {},
{}, {}, {},
{},
{}, {},
{}, {},
{},
{}, {},
{}
>);
}};
)",
// TODO: add CD dtype
to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b),
get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims),
args.gemm_desc.num_groups,
args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k,
args.gemm_config.storage_config.swizzle_a_mode,
args.gemm_config.storage_config.swizzle_b_mode,
args.gemm_config.storage_config.swizzle_cd_mode,
args.gemm_config.pipeline_config.num_stages,
args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads,
// TODO: refactor with cluster M/N
args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1,
args.gemm_config.launch_config.num_sms,
to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation,
to_string(args.gemm_desc.cd_dtype));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.grouped_layout,
args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_cd));
}
};
static void sm90_bf16_gemm(const torch::Tensor& a,
const torch::Tensor& b,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto desc = GemmDesc {
.gemm_type = GemmType::Normal,
.kernel_type = KernelType::KernelNoSF,
.m = m, .n = n, .k = k, .num_groups = 1,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Requires no TMA splits
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), 1,
config.storage_config.swizzle_cd_mode);
// Launch
const SM90BF16GemmRuntime::Args& args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd,
};
const auto code = SM90BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_bf16_gemm", code);
SM90BF16GemmRuntime::launch(runtime, args);
}
static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& d,
const torch::Tensor& m_indices,
const int& num_groups, const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims,
const bool& use_psum_layout,
const std::optional<int>& expected_m_for_psum_layout) {
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(k % 64 == 0);
const auto gemm_type = use_psum_layout ?
GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;
// Only psum layout can use expected m
if (expected_m_for_psum_layout)
DG_HOST_ASSERT(use_psum_layout);
const auto desc = GemmDesc {
.gemm_type = gemm_type,
.kernel_type = KernelType::KernelNoSF,
.m = m, .n = n, .k = k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
.expected_m = expected_m_for_psum_layout.value_or(m),
.expected_n = n, .expected_k = k,
.expected_num_groups = expected_m_for_psum_layout.has_value() ? num_groups : 1
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Requires no TMA splits
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), 1,
config.storage_config.swizzle_cd_mode);
// Launch
const SM90BF16GemmRuntime::Args& args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = m_indices.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd,
};
const auto code = SM90BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code);
SM90BF16GemmRuntime::launch(runtime, args);
}
static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& num_groups, const int& m, const int& n, const int& k,
const int& expected_m,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(k % 64 == 0);
const auto desc = GemmDesc {
.gemm_type = GemmType::MGroupedMasked,
.kernel_type = KernelType::KernelNoSF,
.m = m, .n = n, .k = k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
.expected_m = expected_m, .expected_n = 0, .expected_k = 0, .expected_num_groups = num_groups
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Requires no TMA splits
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), num_groups,
config.storage_config.swizzle_cd_mode);
// Launch
const SM90BF16GemmRuntime::Args& args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = masked_m.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd,
};
const auto code = SM90BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code);
SM90BF16GemmRuntime::launch(runtime, args);
}
static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a,
const torch::Tensor& b,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n,
const std::vector<int>& ks, const torch::Tensor& ks_tensor,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);
int sum_k = 0;
for (const auto k: ks) {
sum_k += k;
DG_HOST_ASSERT(k % 128 == 0);
}
const auto num_groups = static_cast<int>(ks.size());
// Get config using max K for better performance
const auto max_k = *std::max_element(ks.begin(), ks.end());
const auto desc = GemmDesc {
.gemm_type = GemmType::KGroupedContiguous,
.kernel_type = KernelType::KernelNoSF,
.m = m, .n = n, .k = sum_k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
.expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Create tensor descriptors
const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(0)), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(0)), 1,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(1)), num_groups,
config.storage_config.swizzle_cd_mode);
// Launch kernel
const SM90BF16GemmRuntime::Args& args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = ks_tensor.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd,
};
const auto code = SM90BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_bf16_k_grouped_gemm", code);
SM90BF16GemmRuntime::launch(runtime, args);
}
static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
const torch::Tensor& tensor_b,
const torch::Tensor& tensor_d,
const int& b, const int& h, const int& r, const int& d,
const std::string& compiled_dims = "nk") {
const auto desc = GemmDesc {
.gemm_type = GemmType::Batched,
.kernel_type = KernelType::KernelNoSF,
.m = b, .n = d, .k = r, .num_groups = h,
.a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(),
.cd_dtype = tensor_d.scalar_type(),
.major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::K,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
};
const auto config = get_best_config<SM90ArchSpec>(desc);
const int load_block_m = config.storage_config.load_block_m;
const auto tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h,
config.layout.block_k, load_block_m, 1,
tensor_a.stride(0), tensor_a.stride(1),
config.storage_config.swizzle_a_mode);
const int load_block_n = config.storage_config.load_block_n;
const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h,
config.layout.block_k, load_block_n, 1,
tensor_b.stride(1), tensor_b.stride(0),
config.storage_config.swizzle_b_mode);
const int store_block_m = config.storage_config.store_block_m;
const int store_block_n = config.storage_config.store_block_n;
const auto tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h,
store_block_n, store_block_m, 1,
tensor_d.stride(0), tensor_d.stride(1),
config.storage_config.swizzle_cd_mode);
// Launch
const SM90BF16GemmRuntime::Args& args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd,
};
const auto code = SM90BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_bf16_bhr_hdr_bhd", code);
SM90BF16GemmRuntime::launch(runtime, args);
}
static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
const torch::Tensor& tensor_b,
const torch::Tensor& tensor_d,
const int& b, const int& h, const int& r, const int& d,
const std::string& compiled_dims = "nk") {
const auto desc = GemmDesc {
.gemm_type = GemmType::Batched,
.kernel_type = KernelType::KernelNoSF,
.m = b, .n = r, .k = d, .num_groups = h,
.a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(),
.cd_dtype = tensor_d.scalar_type(),
.major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::MN,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
};
const auto config = get_best_config<SM90ArchSpec>(desc);
const int load_block_m = config.storage_config.load_block_m;
const auto tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h,
config.layout.block_k, load_block_m, 1,
tensor_a.stride(0), tensor_a.stride(1),
config.storage_config.swizzle_a_mode);
const int load_block_n = config.storage_config.load_block_n;
const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h,
load_block_n, config.layout.block_k, 1,
tensor_b.stride(1), tensor_b.stride(0),
config.storage_config.swizzle_b_mode);
const int store_block_m = config.storage_config.store_block_m;
const int store_block_n = config.storage_config.store_block_n;
const auto tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h,
store_block_n, store_block_m, 1,
tensor_d.stride(0), tensor_d.stride(1),
config.storage_config.swizzle_cd_mode);
// Launch
const SM90BF16GemmRuntime::Args& args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd,
};
const auto code = SM90BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_bf16_bhd_hdr_bhr", code);
SM90BF16GemmRuntime::launch(runtime, args);
}
} // namespace deep_gemm