Files
DeepGEMM/csrc/jit_kernels/heuristics/sm90.hpp
2025-09-25 16:19:07 +08:00

148 lines
6.3 KiB
C++

#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
// Reuse some types in the JIT modules
#include <deep_gemm/common/types.hpp>
#include "common.hpp"
namespace deep_gemm {
// Architecture-specific heuristics for SM90: candidate block shapes, legality checks,
// thread configuration and shared-memory size helpers. Every member is static.
struct SM90ArchSpec {
    // Shared-memory capacity used when budgeting a configuration
    // NOTE(review): presumably bytes (232448 == 227 * 1024) — confirm against the callers
    static constexpr int smem_capacity = 232448;

    // Returns the candidate `block_n` values in ascending order: every 16th value up to 256,
    // starting from 8 for FP32 output and from 16 for any other output type
    static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
        constexpr int max_block_n = 256, step = 16;
        // Avoid bank conflicts for FP32 output (the offset of 8 keeps every FP32 candidate off multiples of 16)
        const int start = cd_dtype == torch::kFloat ? 8 : 16;
        std::vector<int> candidates;
        candidates.reserve((max_block_n - start) / step + 1);
        for (int i = start; i <= max_block_n; i += step)
            candidates.push_back(i);
        return candidates;
    }

    // Returns `block_m` unchanged; `multicast_config` is ignored by this spec
    static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) {
        return block_m;
    }

    // Returns `block_n` unchanged; `multicast_config` is ignored by this spec
    static int get_ab_load_block_n(const MulticastConfig& multicast_config, const int& block_n) {
        return block_n;
    }

    // Returns 64 (`wgmma_m`) when `single_warpgroup_sync` is set, otherwise `block_m`
    static int get_cd_store_block_m(const int& block_m, const bool& single_warpgroup_sync = false) {
        constexpr int wgmma_m = 64;
        return single_warpgroup_sync ? wgmma_m : block_m;
    }

    // Returns `block_n` unchanged
    static int get_cd_store_block_n(const int& block_n) {
        return block_n;
    }

    // C/D swizzling is enabled for every output type except FP32
    static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
        return cd_dtype != torch::kFloat;
    }

    // Returns whether the `(block_m, block_n)` pair may be used for the given kernel type and output type.
    // `major_a`, `major_b`, `ab_dtype` and `block_k` are accepted but not consulted by this spec.
    static bool is_block_size_legal(const KernelType& kernel_type,
                                    const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                    const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
                                    const int& block_m, const int& block_n, const int& block_k) {
        const bool is_fp32_output = cd_dtype == torch::kFloat;

        // SM90 FP32 output does not support `block_m == 256`
        if (is_fp32_output and block_m == 256)
            return false;

        // Avoid large C/D shared memory for FP32 output
        // Ensure `num_stages >= 4` (for 1D1D Kernel), `num_stages >= 3` (for No SF kernel)
        if (block_n > 128 and is_fp32_output) {
            if (kernel_type == KernelType::Kernel1D1D and block_n > 152)
                return false;
            if (kernel_type == KernelType::KernelNoSF and block_n > 200)
                return false;
        }

        // Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
        // Or too many register spills
        // Hence only 144, 160 and 192 are accepted above 128 for the 1D2D kernel
        if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
            return false;

        // Avoid bank conflicts for FP32 output
        if (is_fp32_output and block_n % 16 == 0)
            return false;

        // The block sizes cannot be too large (for enough registers), so at least one dim must not exceed 128
        return block_m <= 128 or block_n <= 128;
    }

    // Returns whether `num_stages` may be used with the given block shape.
    // `cd_dtype` and `block_m` are accepted but not consulted by this spec.
    static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
                                    const int& num_stages,
                                    const int& block_m, const int& block_n, const int& block_k) {
        // Unrolling both stages and `num_former_iters` will cause large code size
        if (ab_dtype == torch::kFloat8_e4m3fn and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4)
            return num_stages <= 4;
        return true;
    }

    // This spec always asks for the number of SMs to be minimized
    static bool should_minimize_num_sms() {
        return true;
    }

    // Returns a pair of multicast legality flags: the first is derived from `(n, block_n)`,
    // the second from `(m, block_m)`.
    // NOTE(review): the meaning of the `2` and the trailing boolean passed to `is_multicast_legal`
    // is defined by that helper (not visible here) — confirm before changing them
    static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
                                                        const int& m, const int& n, const int& block_m, const int& block_n,
                                                        const int& num_sms) {
        // Disable multicast when the number of k-groups is large (a heuristic)
        if (gemm_type == GemmType::KGroupedContiguous and num_groups > 4)
            return {false, false};

        return {
            is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
            // For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
            is_multicast_legal(m, block_m, 2, num_sms, false)
                and (gemm_type != GemmType::MGroupedMasked or is_multicast_legal(n, block_n, 2, num_sms, true))
        };
    }

    // Builds the SM90 thread configuration for a block shape.
    // Tiles with at most 64 rows (one `wgmma_m`, see `get_cd_store_block_m`) get a single group of 128 threads
    // in the second argument, taller tiles get two groups. Using `<=` rather than `==` keeps tiles shorter
    // than 64 rows on a single group instead of handing them two.
    // `kernel_type` and `block_n` are accepted but not consulted by this spec.
    // NOTE(review): the first argument is presumably the TMA thread count and the second the math
    // thread count — confirm against `ThreadConfig::sm90`
    static ThreadConfig get_thread_config(const KernelType& kernel_type,
                                          const int& block_m, const int& block_n) {
        constexpr int wgmma_m = 64;
        const int num_math_groups = block_m <= wgmma_m ? 1 : 2;
        return ThreadConfig::sm90(128, num_math_groups * 128);
    }

    // Returns the size of one `block_m x block_n` tile of `cd_dtype` elements, using `c10::elementSize`.
    // `kernel_type` and `swizzle_cd_mode` are accepted but not consulted by this spec.
    static int get_smem_cd_size(const KernelType& kernel_type,
                                const int& block_m, const int& block_n,
                                const int& swizzle_cd_mode, const at::ScalarType& cd_dtype) {
        return block_m * block_n * static_cast<int>(c10::elementSize(cd_dtype));
    }

    // Returns the per-stage `{SFA, SFB}` sizes: `{0, 0}` for BF16 inputs; otherwise SFA is `block_m` floats
    // and SFB is non-zero only for the 1D1D kernel.
    // `block_k` and `cd_dtype` are accepted but not consulted by this spec.
    static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
                                                          const int& block_m, const int& block_n, const int& block_k,
                                                          const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) {
        if (ab_dtype == torch::kBFloat16)
            return {0, 0};

        int smem_sfa_per_stage = block_m * static_cast<int>(sizeof(float));
        int smem_sfb_per_stage = 0;
        if (kernel_type == KernelType::Kernel1D1D) {
            // NOTES: `128` is for 2D TMA alignment requirement
            smem_sfb_per_stage = align(block_n * 4, 128);
        }
        return {smem_sfa_per_stage, smem_sfb_per_stage};
    }

    // Returns the extra SFB size: one float per K block (`ceil_div(k, block_k)`), doubled when `block_k`
    // is not a multiple of `block_n`, then aligned to 8.
    // `m`, `n` and `block_m` are accepted but not consulted by this spec.
    static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
                                       const int& block_m, const int& block_n, const int& block_k) {
        const int use_uniform_sfb = block_k % block_n == 0 ? 1 : 2;
        return align<int>(ceil_div(k, block_k) * static_cast<int>(sizeof(float)) * use_uniform_sfb, 8);
    }

    // Returns `num_stages * 8 * 2`
    // NOTE(review): presumably two 8-byte barriers per stage — confirm against the kernels
    static int get_barrier_smem_size(const int& num_stages) {
        return num_stages * 8 * 2;
    }

    // This spec reserves no shared memory for a tensor-memory pointer
    static int get_tmem_ptr_smem_size() {
        return 0;
    }

    // Reserves room for four `CUtensorMap` objects for k-grouped contiguous GEMMs, nothing otherwise
    static int get_tensormap_smem_size(const GemmType& gemm_type) {
        return gemm_type == GemmType::KGroupedContiguous ? 4 * static_cast<int>(sizeof(CUtensorMap)) : 0;
    }
};
} // namespace deep_gemm