Make various updates and fixes (#198)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <deep_gemm/common/types.hpp>
|
||||
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../../utils/layout.hpp"
|
||||
|
||||
@@ -80,18 +82,19 @@ static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
|
||||
return divisible and num_sms % num_multicast == 0;
|
||||
}
|
||||
|
||||
static int get_swizzle_mode(const int& block_size, const int& elem_size) {
|
||||
template <typename size_type_t>
|
||||
static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) {
|
||||
// `> 0` means interleaving
|
||||
// 16B actually means non-swizzling (but interleaving)
|
||||
for (const int& mode: {128, 64, 32, 16}) {
|
||||
if ((block_size * elem_size) % mode == 0)
|
||||
if ((block_size * static_cast<int>(elem_size)) % mode == 0)
|
||||
return mode;
|
||||
}
|
||||
DG_HOST_UNREACHABLE("Unreachable");
|
||||
}
|
||||
|
||||
template <typename ArchSpec>
|
||||
static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
|
||||
static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const KernelType& kernel_type,
|
||||
const int& m, const int& n, const int& k,
|
||||
const int& block_m, const int& block_n, const int& block_k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
@@ -104,7 +107,7 @@ static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
|
||||
const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n);
|
||||
const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size);
|
||||
const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size);
|
||||
const int& swizzle_cd_mode = get_swizzle_mode(block_n, cd_elem_size);
|
||||
const int& swizzle_cd_mode = ArchSpec::enable_cd_swizzle(cd_dtype) ? get_swizzle_mode(block_n, cd_elem_size) : 0;
|
||||
|
||||
// Different archs have different epilogue pipelines
|
||||
const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype);
|
||||
@@ -121,9 +124,11 @@ static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
|
||||
// M-barriers and tensor memory pointers
|
||||
const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages);
|
||||
const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size();
|
||||
const int& smem_tensor_map = ArchSpec::get_tensormap_smem_size(gemm_type);
|
||||
|
||||
// Sum them up
|
||||
int smem_size = 0;
|
||||
smem_size += smem_tensor_map;
|
||||
smem_size += smem_cd;
|
||||
smem_size += num_stages * smem_a_per_stage;
|
||||
smem_size += num_stages * smem_b_per_stage;
|
||||
@@ -151,15 +156,12 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
|
||||
|
||||
// Select M/N block sizes
|
||||
// TODO: support `% 16 == 8` block size on SM90
|
||||
auto block_ms = std::vector{64, 128, 256};
|
||||
if (gemm_type == GemmType::MGroupedContiguous)
|
||||
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
|
||||
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
|
||||
block_ms = std::vector{64, 128};
|
||||
std::vector<int> block_ns;
|
||||
for (int i = 16; i <= 256; i += 16)
|
||||
block_ns.push_back(i);
|
||||
const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype);
|
||||
|
||||
// K block size is selected in a fixed manner
|
||||
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
|
||||
@@ -214,9 +216,9 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0);
|
||||
|
||||
// Decide the number of TMA multicasts and whether broadcast on A
|
||||
MulticastConfig best_multicast_config = {1, true};
|
||||
MulticastConfig best_multicast_config = {1, false};
|
||||
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
|
||||
gemm_type, m, n, best_block_m, best_block_n, num_sms);
|
||||
gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms);
|
||||
const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
|
||||
bool order[2] = {false, true};
|
||||
if (best_block_m > best_block_n)
|
||||
@@ -232,11 +234,11 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
constexpr int smem_capacity = ArchSpec::smem_capacity;
|
||||
int best_num_stages = 0;
|
||||
SharedMemoryConfig best_smem_config;
|
||||
for (int num_stages = std::min(12, ceil_div(k, block_k)); num_stages > 0; -- num_stages) {
|
||||
for (int num_stages = 12; num_stages > 0; -- num_stages) {
|
||||
if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
|
||||
continue;
|
||||
|
||||
best_smem_config = get_smem_config<ArchSpec>(kernel_type,
|
||||
best_smem_config = get_smem_config<ArchSpec>(gemm_type, kernel_type,
|
||||
m, n, k,
|
||||
best_block_m, best_block_n, block_k,
|
||||
major_a, major_b,
|
||||
|
||||
@@ -12,6 +12,15 @@ namespace deep_gemm {
|
||||
struct SM100ArchSpec {
|
||||
static constexpr int smem_capacity = 232448;
|
||||
|
||||
static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
|
||||
// 16 is for better SM usage
|
||||
// Stride 32 is due to low-performance swizzle-16/32B
|
||||
std::vector<int> candidates = {16};
|
||||
for (int i = 32; i <= 256; i += 32)
|
||||
candidates.push_back(i);
|
||||
return candidates;
|
||||
}
|
||||
|
||||
static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) {
|
||||
return block_m / (config.is_multicast_on_a ? config.num_multicast : 1);
|
||||
}
|
||||
@@ -29,6 +38,10 @@ struct SM100ArchSpec {
|
||||
return block_n;
|
||||
}
|
||||
|
||||
static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
|
||||
const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) {
|
||||
constexpr int num_utccp_aligned_elems = 128;
|
||||
@@ -86,7 +99,7 @@ struct SM100ArchSpec {
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
|
||||
const int& m, const int& n, const int& block_m, const int& block_n,
|
||||
const int& num_sms) {
|
||||
// TODO: support other layouts
|
||||
@@ -138,12 +151,17 @@ struct SM100ArchSpec {
|
||||
// TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers
|
||||
// NOTES: 1D2D kernel will not use the with-SF full barriers
|
||||
// NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
|
||||
return num_stages * 8 * 3 + 2 * 8 * 2;
|
||||
// NOTES: the last barrier is for tensor core utilization control
|
||||
return num_stages * 8 * 3 + 2 * 8 * 2 + 8;
|
||||
}
|
||||
|
||||
static int get_tmem_ptr_smem_size() {
|
||||
return 4;
|
||||
}
|
||||
|
||||
static int get_tensormap_smem_size(const GemmType& gemm_type) {
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
@@ -11,6 +11,15 @@ namespace deep_gemm {
|
||||
struct SM90ArchSpec {
|
||||
static constexpr int smem_capacity = 232448;
|
||||
|
||||
static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
|
||||
// Avoid bank conflicts for FP32 output
|
||||
const auto& start = cd_dtype == torch::kFloat ? 8 : 16;
|
||||
std::vector<int> candidates;
|
||||
for (int i = start; i <= 256; i += 16)
|
||||
candidates.push_back(i);
|
||||
return candidates;
|
||||
}
|
||||
|
||||
static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) {
|
||||
return block_m;
|
||||
}
|
||||
@@ -19,26 +28,35 @@ struct SM90ArchSpec {
|
||||
return block_n;
|
||||
}
|
||||
|
||||
static int get_cd_store_block_m(const int& block_m) {
|
||||
return block_m;
|
||||
static int get_cd_store_block_m(const int& block_m, const bool& single_warpgroup_sync = false) {
|
||||
constexpr int wgmma_m = 64;
|
||||
return single_warpgroup_sync ? wgmma_m : block_m;
|
||||
}
|
||||
|
||||
static int get_cd_store_block_n(const int& block_n) {
|
||||
return block_n;
|
||||
}
|
||||
|
||||
static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
|
||||
return cd_dtype != torch::kFloat;
|
||||
}
|
||||
|
||||
static bool is_block_size_legal(const KernelType& kernel_type,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
// FP32 output does not support `block_m == 256`
|
||||
// SM90 FP32 output does not support `block_m == 256`
|
||||
if (cd_dtype == at::kFloat and block_m == 256)
|
||||
return false;
|
||||
|
||||
// TODO: more general block N selection
|
||||
// Must be some fixed block N selections
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 and block_n != 152))
|
||||
return false;
|
||||
// Avoid large C/D shared memory for FP32 output
|
||||
// Ensure `num_stages >= 4` (for 1D1D Kernel), `num_stages >= 3` (for No SF kernel)
|
||||
if (block_n > 128 and cd_dtype == torch::kFloat) {
|
||||
if (kernel_type == KernelType::Kernel1D1D and block_n > 152)
|
||||
return false;
|
||||
if (kernel_type == KernelType::KernelNoSF and block_n > 200)
|
||||
return false;
|
||||
}
|
||||
|
||||
// Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
|
||||
// Or too many register spills
|
||||
@@ -66,9 +84,13 @@ struct SM90ArchSpec {
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
|
||||
const int& m, const int& n, const int& block_m, const int& block_n,
|
||||
const int& num_sms) {
|
||||
// Disable multicast when the number of k-groups is large (a heuristic)
|
||||
if (gemm_type == GemmType::KGroupedContiguous and num_groups > 4)
|
||||
return {false, false};
|
||||
|
||||
return {
|
||||
is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
|
||||
// For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
|
||||
@@ -96,9 +118,10 @@ struct SM90ArchSpec {
|
||||
|
||||
int smem_sfa_per_stage = block_m * static_cast<int>(sizeof(float));
|
||||
int smem_sfb_per_stage = 0;
|
||||
// TODO: figure out here
|
||||
if (kernel_type == KernelType::Kernel1D1D)
|
||||
smem_sfb_per_stage = align(block_n * 4, block_k);
|
||||
if (kernel_type == KernelType::Kernel1D1D) {
|
||||
// NOTES: `128` is for 2D TMA alignment requirement
|
||||
smem_sfb_per_stage = align(block_n * 4, 128);
|
||||
}
|
||||
return {smem_sfa_per_stage, smem_sfb_per_stage};
|
||||
}
|
||||
|
||||
@@ -109,13 +132,16 @@ struct SM90ArchSpec {
|
||||
}
|
||||
|
||||
static int get_barrier_smem_size(const int& num_stages) {
|
||||
// For 1D1D kernels, there is an extra barrier for accumulation
|
||||
return (num_stages + 1) * 8 * 2;
|
||||
return num_stages * 8 * 2;
|
||||
}
|
||||
|
||||
static int get_tmem_ptr_smem_size() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int get_tensormap_smem_size(const GemmType& gemm_type) {
|
||||
return gemm_type == GemmType::KGroupedContiguous ? 4 * static_cast<int>(sizeof(CUtensorMap)) : 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
12
csrc/jit_kernels/impls/epilogue.hpp
Normal file
12
csrc/jit_kernels/impls/epilogue.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
static std::string get_default_epilogue_type(const std::optional<std::string>& epilogue_type) {
|
||||
return epilogue_type.value_or("EpilogueIdentity");
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -4,6 +4,8 @@
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "../../utils/system.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
@@ -51,7 +53,11 @@ static std::string to_string(const at::ScalarType& dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype) {
|
||||
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype,
|
||||
const bool& allow_tf32) {
|
||||
if (allow_tf32 and dtype == torch::kFloat)
|
||||
return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32;
|
||||
|
||||
switch (dtype) {
|
||||
case torch::kInt: return CU_TENSOR_MAP_DATA_TYPE_INT32;
|
||||
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
@@ -61,9 +67,14 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
|
||||
}
|
||||
}
|
||||
|
||||
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode) {
|
||||
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) {
|
||||
if (base != 0) {
|
||||
DG_HOST_ASSERT(base == 32 and mode == 128);
|
||||
return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
|
||||
}
|
||||
|
||||
switch (mode) {
|
||||
case 0: return CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
case 0:
|
||||
case 16: return CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
case 32: return CU_TENSOR_MAP_SWIZZLE_32B;
|
||||
case 64: return CU_TENSOR_MAP_SWIZZLE_64B;
|
||||
@@ -76,7 +87,8 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
|
||||
int gmem_inner_dim, int gmem_outer_dim,
|
||||
int smem_inner_dim, int smem_outer_dim,
|
||||
const int& gmem_outer_stride,
|
||||
const int& swizzle_mode) {
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
const auto& elem_size = static_cast<int>(t.element_size());
|
||||
if (swizzle_mode != 0)
|
||||
smem_inner_dim = swizzle_mode / elem_size;
|
||||
@@ -87,14 +99,42 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
|
||||
const cuuint64_t gmem_strides[1] = {static_cast<cuuint64_t>(gmem_outer_stride * elem_size), };
|
||||
const cuuint32_t elem_strides[2] = {1, 1};
|
||||
if (get_env<int>("DG_JIT_DEBUG")) {
|
||||
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d, elem size: %d\n",
|
||||
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d\n",
|
||||
gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim,
|
||||
gmem_outer_stride, swizzle_mode, elem_size);
|
||||
gmem_outer_stride, swizzle_mode, swizzle_base, elem_size);
|
||||
}
|
||||
DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
|
||||
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type()),
|
||||
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
|
||||
2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode),
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
|
||||
CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
|
||||
return tensor_map;
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_3d_desc(const torch::Tensor& t,
|
||||
const int& gmem_dim_0, const int& gmem_dim_1, const int& gmem_dim_2,
|
||||
const int& smem_dim_0, const int& smem_dim_1, const int& smem_dim_2,
|
||||
const int& gmem_stride_0, const int& gmem_stride_1,
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
const auto& elem_size = static_cast<int>(t.element_size());
|
||||
if (swizzle_mode != 0)
|
||||
DG_HOST_ASSERT(smem_dim_0 == swizzle_mode / elem_size);
|
||||
|
||||
CUtensorMap tensor_map;
|
||||
const cuuint64_t gmem_dims[3] = {static_cast<cuuint64_t>(gmem_dim_0), static_cast<cuuint64_t>(gmem_dim_1), static_cast<cuuint64_t>(gmem_dim_2),};
|
||||
const cuuint32_t smem_dims[3] = {static_cast<cuuint32_t>(smem_dim_0), static_cast<cuuint32_t>(smem_dim_1), static_cast<cuuint32_t>(smem_dim_2)};
|
||||
const cuuint64_t gmem_strides[2] = {static_cast<cuuint64_t>(gmem_stride_0 * elem_size), static_cast<cuuint64_t>(gmem_stride_1 * elem_size)};
|
||||
const cuuint32_t elem_strides[3] = {1, 1, 1};
|
||||
if (get_env<int>("DG_JIT_DEBUG")) {
|
||||
printf("Making 3D TMA desc: global memory: %d %d %d, shared memory: %d %d %d, outer stride: %d %d, swizzle: %d, elem size: %d\n",
|
||||
gmem_dim_0, gmem_dim_1, gmem_dim_2, smem_dim_0, smem_dim_1, smem_dim_2,
|
||||
gmem_stride_0, gmem_stride_1, swizzle_mode, elem_size);
|
||||
}
|
||||
DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
|
||||
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
|
||||
3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
|
||||
CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
|
||||
return tensor_map;
|
||||
}
|
||||
@@ -105,7 +145,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
|
||||
const int& block_m, const int& block_k,
|
||||
const int& outer_stride,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
if (num_groups > 1)
|
||||
DG_HOST_ASSERT(major == cute::UMMA::Major::K);
|
||||
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
|
||||
@@ -114,7 +155,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
|
||||
gmem_inner_dim, gmem_outer_dim,
|
||||
smem_inner_dim, smem_outer_dim,
|
||||
outer_stride,
|
||||
swizzle_mode);
|
||||
swizzle_mode, swizzle_base,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
|
||||
@@ -123,7 +165,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
|
||||
const int& block_n, const int& block_k,
|
||||
const int& outer_stride,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n);
|
||||
const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n);
|
||||
|
||||
@@ -132,7 +175,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
|
||||
gmem_inner_dim, gmem_outer_dim * num_groups,
|
||||
smem_inner_dim, smem_outer_dim,
|
||||
outer_stride,
|
||||
swizzle_mode);
|
||||
swizzle_mode, swizzle_base,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
|
||||
@@ -140,15 +184,16 @@ static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
|
||||
const int& block_m, const int& block_n,
|
||||
const int& outer_stride,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
// Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode`
|
||||
// bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
|
||||
return make_tma_2d_desc(t,
|
||||
shape_n, shape_m * num_groups,
|
||||
block_n, block_m,
|
||||
outer_stride,
|
||||
swizzle_mode);
|
||||
swizzle_mode, swizzle_base,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
|
||||
@@ -156,7 +201,8 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
|
||||
int shape_mn, int shape_k,
|
||||
const int& block_mn, const int& block_k,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
DG_HOST_ASSERT(major == cute::UMMA::Major::MN);
|
||||
|
||||
// TODO: maybe swizzle SF as well
|
||||
@@ -167,7 +213,8 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
|
||||
shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups,
|
||||
block_mn, 1,
|
||||
shape_mn,
|
||||
swizzle_mode);
|
||||
swizzle_mode, swizzle_base,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
@@ -42,7 +42,7 @@ static void __instantiate_kernel() {{
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
@@ -56,7 +56,7 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.num_groups,
|
||||
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.num_stages,
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms,
|
||||
@@ -80,8 +80,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
// TODO: test other Ks
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
const auto& aligned_k = align(k, 64);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Normal, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
@@ -122,7 +121,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
|
||||
|
||||
// Launch
|
||||
const SM100BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
|
||||
137
csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp
Normal file
137
csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp
Normal file
@@ -0,0 +1,137 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM100BmkBnkMnRuntime final: public LaunchRuntime<SM100BmkBnkMnRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int s, m, n, k;
|
||||
int block_m, block_n, block_k;
|
||||
int split_factor;
|
||||
int swizzle_ab_mode, swizzle_cd_mode;
|
||||
int num_stages;
|
||||
int num_threads;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
|
||||
CUtensorMap tensor_map_a;
|
||||
CUtensorMap tensor_map_b;
|
||||
CUtensorMap tensor_map_d;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm100_bmn_bnk_mn_gemm_impl<
|
||||
{}, {}, {},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
args.m, args.n, args.k,
|
||||
args.block_m, args.block_n, args.block_k,
|
||||
args.split_factor,
|
||||
args.swizzle_ab_mode, args.swizzle_cd_mode,
|
||||
args.num_stages, args.num_threads);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.s, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a,
|
||||
const torch::Tensor &b,
|
||||
const torch::Tensor &d,
|
||||
const int &s, const int &m, const int &n, const int &k) {
|
||||
constexpr int block_m = 128;
|
||||
constexpr int block_n = 128;
|
||||
constexpr int block_k = 64;
|
||||
constexpr int num_threads = 128;
|
||||
DG_HOST_ASSERT(k % block_k == 0);
|
||||
DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0);
|
||||
DG_HOST_ASSERT(static_cast<int64_t>(s) * static_cast<int64_t>(std::max(m, n)) <= std::numeric_limits<int>::max());
|
||||
|
||||
const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast<int>(a.element_size()));
|
||||
const int swizzle_cd_mode = get_swizzle_mode(block_n, static_cast<int>(d.element_size()));
|
||||
|
||||
// Get best config
|
||||
const int num_sms = device_runtime->get_num_sms();
|
||||
const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);
|
||||
const int num_sk_blocks = s * (k / block_k);
|
||||
const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1));
|
||||
|
||||
// Select best number of stages
|
||||
// NOTES: we select 4 as start, as it is tested to be faster than values > 4
|
||||
int num_stages = 4, smem_size = 0;
|
||||
while (true) {
|
||||
const int& smem_cd = block_m * swizzle_cd_mode * 2;
|
||||
const int& smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t);
|
||||
const int& smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t);
|
||||
const int& smem_barrier = SM100ArchSpec::get_barrier_smem_size(num_stages);
|
||||
const int& smem_tmem_ptr = SM100ArchSpec::get_tmem_ptr_smem_size();
|
||||
|
||||
smem_size = 0;
|
||||
smem_size += smem_cd;
|
||||
smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages;
|
||||
smem_size += smem_barrier;
|
||||
smem_size += smem_tmem_ptr;
|
||||
if (smem_size <= SM100ArchSpec::smem_capacity)
|
||||
break;
|
||||
|
||||
-- num_stages;
|
||||
}
|
||||
DG_HOST_ASSERT(num_stages > 0);
|
||||
|
||||
// Print configs
|
||||
if (get_env("DG_JIT_DEBUG", 0)) {
|
||||
printf("S: %d, M: %d, N: %d, K: %d -> "
|
||||
"block M: %d, block N: %d, block K: %d, split-K factor: %d"
|
||||
"stages: %d, shared memory: %d, swizzle AB: %d, swizzle CD: %d\n",
|
||||
s, m, n, k, block_m, block_n, block_k, split_factor,
|
||||
num_stages, smem_size, swizzle_ab_mode, swizzle_cd_mode);
|
||||
}
|
||||
|
||||
const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
|
||||
const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);
|
||||
const auto& tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode);
|
||||
|
||||
const SM100BmkBnkMnRuntime::Args& args = {
|
||||
.s = s, .m = m, .n = n, .k = k,
|
||||
.block_m = block_m, .block_n = block_n, .block_k = block_k,
|
||||
.split_factor = split_factor,
|
||||
.swizzle_ab_mode = swizzle_ab_mode,
|
||||
.swizzle_cd_mode = swizzle_cd_mode,
|
||||
.num_stages = num_stages,
|
||||
.num_threads = num_threads,
|
||||
.launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_threads, smem_size),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d
|
||||
};
|
||||
const auto& code = SM100BmkBnkMnRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code);
|
||||
SM100BmkBnkMnRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -9,6 +9,8 @@
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
|
||||
#include "epilogue.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
@@ -18,6 +20,7 @@ public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
const std::optional<std::string>& epilogue_type;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
@@ -44,11 +47,12 @@ static void __instantiate_kernel() {{
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}, {}
|
||||
{}, {}, {},
|
||||
{}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -57,11 +61,12 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.num_groups,
|
||||
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.num_stages,
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms,
|
||||
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype));
|
||||
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype),
|
||||
get_default_epilogue_type(args.epilogue_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -80,7 +85,8 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const std::string& compiled_dims,
|
||||
const std::optional<std::string>& epilogue_type = std::nullopt) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Normal, KernelType::Kernel1D1D,
|
||||
@@ -99,7 +105,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
@@ -129,6 +135,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = epilogue_type,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -186,6 +193,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -244,6 +252,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -324,6 +333,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
|
||||
.m = m, .n = n, .k = sum_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
|
||||
@@ -18,6 +18,7 @@ public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
const std::optional<std::string>& epilogue_type;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
@@ -46,7 +47,8 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}
|
||||
{}, {},
|
||||
{}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -59,7 +61,8 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms,
|
||||
to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype));
|
||||
to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype),
|
||||
get_default_epilogue_type(args.epilogue_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -78,7 +81,8 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const std::string& compiled_dims,
|
||||
const std::optional<std::string>& epilogue_type = std::nullopt) {
|
||||
DG_HOST_ASSERT(not c.has_value());
|
||||
|
||||
const auto& aligned_k = align(k, 128);
|
||||
@@ -98,7 +102,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
@@ -111,6 +115,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = epilogue_type,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -164,6 +169,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -218,6 +224,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const t
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
|
||||
@@ -41,7 +41,7 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {}
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -53,7 +53,8 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
|
||||
to_string(args.gemm_config.cd_dtype));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -73,10 +74,10 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(not c.has_value());
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
|
||||
const auto& aligned_k = align(k, 64);
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::Normal, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
@@ -102,7 +103,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
|
||||
|
||||
// Launch
|
||||
const SM90BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
|
||||
131
csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp
Normal file
131
csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp
Normal file
@@ -0,0 +1,131 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM90BmkBnkMnRuntime final: public LaunchRuntime<SM90BmkBnkMnRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int s, m, n, k;
|
||||
int block_m, block_n, block_k;
|
||||
int split_factor;
|
||||
int num_stages;
|
||||
int num_tma_threads, num_math_threads;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
|
||||
CUtensorMap tensor_map_a;
|
||||
CUtensorMap tensor_map_b;
|
||||
float* d;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm90_bmn_bnk_mn_gemm_impl<
|
||||
{}, {}, {},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
args.m, args.n, args.k,
|
||||
args.block_m, args.block_n, args.block_k,
|
||||
args.split_factor,
|
||||
args.num_stages,
|
||||
args.num_tma_threads, args.num_math_threads);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.s, args.tensor_map_a, args.tensor_map_b, args.d));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a,
|
||||
const torch::Tensor &b,
|
||||
const torch::Tensor &d,
|
||||
const int &s, const int &m, const int &n, const int &k) {
|
||||
constexpr int block_m = 128;
|
||||
constexpr int block_n = 128;
|
||||
constexpr int block_k = 64;
|
||||
constexpr int num_tma_threads = 128;
|
||||
constexpr int num_math_threads = 256;
|
||||
DG_HOST_ASSERT(k % block_k == 0);
|
||||
DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0);
|
||||
DG_HOST_ASSERT(static_cast<int64_t>(s) * static_cast<int64_t>(std::max(m, n)) <= std::numeric_limits<int>::max());
|
||||
|
||||
const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast<int>(a.element_size()));
|
||||
DG_HOST_ASSERT(swizzle_ab_mode == 128);
|
||||
|
||||
// Get best config
|
||||
const int num_sms = device_runtime->get_num_sms();
|
||||
const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);
|
||||
const int num_sk_blocks = s * (k / block_k);
|
||||
const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1));
|
||||
|
||||
// Select best number of stages
|
||||
int num_stages = 4, smem_size = 0;
|
||||
while (true) {
|
||||
const int& smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t);
|
||||
const int& smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t);
|
||||
const int& smem_barrier = SM90ArchSpec::get_barrier_smem_size(num_stages);
|
||||
|
||||
smem_size = 0;
|
||||
smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages;
|
||||
smem_size += smem_barrier;
|
||||
|
||||
if (smem_size <= SM90ArchSpec::smem_capacity)
|
||||
break;
|
||||
|
||||
-- num_stages;
|
||||
}
|
||||
DG_HOST_ASSERT(num_stages > 0);
|
||||
|
||||
// Print configs
|
||||
if (get_env("DG_JIT_DEBUG", 0)) {
|
||||
printf("S: %d, M: %d, N: %d, K: %d -> "
|
||||
"block M: %d, block N: %d, block K: %d, split-K factor: %d"
|
||||
"stages: %d, shared memory: %d, swizzle AB: %d\n",
|
||||
s, m, n, k, block_m, block_n, block_k, split_factor,
|
||||
num_stages, smem_size, swizzle_ab_mode);
|
||||
}
|
||||
|
||||
const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
|
||||
const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);
|
||||
|
||||
const SM90BmkBnkMnRuntime::Args& args = {
|
||||
.s = s, .m = m, .n = n, .k = k,
|
||||
.block_m = block_m, .block_n = block_n, .block_k = block_k,
|
||||
.split_factor = split_factor,
|
||||
.num_stages = num_stages,
|
||||
.num_tma_threads = num_tma_threads,
|
||||
.num_math_threads = num_math_threads,
|
||||
.launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_tma_threads + num_math_threads, smem_size),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.d = d.data_ptr<float>()
|
||||
};
|
||||
const auto& code = SM90BmkBnkMnRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code);
|
||||
SM90BmkBnkMnRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
214
csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp
Normal file
214
csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp
Normal file
@@ -0,0 +1,214 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime<SM90FP8Gemm1D1DRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
|
||||
void *gmem_a_ptr;
|
||||
void *gmem_b_ptr;
|
||||
void *grouped_layout;
|
||||
void *tensor_map_buffer;
|
||||
CUtensorMap tensor_map_a_base;
|
||||
CUtensorMap tensor_map_b_base;
|
||||
CUtensorMap tensor_map_sfa;
|
||||
CUtensorMap tensor_map_sfb;
|
||||
CUtensorMap tensor_map_d;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d1d_impl<
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
|
||||
args.num_groups,
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.gemm_config.num_stages,
|
||||
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
|
||||
to_string(args.gemm_config.cd_dtype));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.gmem_a_ptr, args.gmem_b_ptr,
|
||||
args.grouped_layout,
|
||||
args.tensor_map_buffer,
|
||||
args.m, args.n, args.k,
|
||||
args.tensor_map_a_base, args.tensor_map_b_base,
|
||||
args.tensor_map_sfa, args.tensor_map_sfb,
|
||||
args.tensor_map_d));
|
||||
}
|
||||
};
|
||||
|
||||
static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::Normal, KernelType::Kernel1D1D,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
|
||||
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
|
||||
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k, k, 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k, k, 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, 1, 0);
|
||||
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
|
||||
config.block_n, config.block_k, 1, 0);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m, true),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
0);
|
||||
|
||||
// Launch
|
||||
const SM90FP8Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.gmem_a_ptr = nullptr,
|
||||
.gmem_b_ptr = nullptr,
|
||||
.grouped_layout = nullptr,
|
||||
.tensor_map_buffer = nullptr,
|
||||
.tensor_map_a_base = tensor_map_a,
|
||||
.tensor_map_b_base = tensor_map_b,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
.tensor_map_sfb = tensor_map_sfb,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
};
|
||||
const auto& code = SM90FP8Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code);
|
||||
|
||||
SM90FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n,
|
||||
const std::vector<int>& ks, const torch::Tensor& ks_tensor,
|
||||
const torch::Tensor& tensor_map_buffer,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
|
||||
// Get config using max K for better performance
|
||||
const auto& num_groups = static_cast<int>(ks.size());
|
||||
const auto& max_k = *std::max_element(ks.begin(), ks.end());
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
|
||||
m, n, max_k, num_groups, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
|
||||
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
|
||||
|
||||
int first_k = 0, sum_k = 0, sum_sf_k = 0;
|
||||
for (int i = 0; i < num_groups; ++ i) {
|
||||
if (first_k == 0 and ks[i] != 0)
|
||||
first_k = ks[i];
|
||||
sum_k += ks[i], sum_sf_k += ceil_div(ks[i], 128);
|
||||
DG_HOST_ASSERT(ks[i] % 128 == 0);
|
||||
}
|
||||
const auto& tensor_map_a_base = make_tma_a_desc(major_a, a, m, first_k,
|
||||
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k, first_k, 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b_base = make_tma_b_desc(major_b, b, n, first_k,
|
||||
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k, first_k, 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128,
|
||||
config.block_m, config.block_k, 1, 0);
|
||||
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128,
|
||||
config.block_n, config.block_k, 1, 0);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m, true),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), num_groups,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
|
||||
// Launch
|
||||
const SM90FP8Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = sum_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.gmem_a_ptr = a.data_ptr(),
|
||||
.gmem_b_ptr = b.data_ptr(),
|
||||
.grouped_layout = ks_tensor.data_ptr(),
|
||||
.tensor_map_buffer = tensor_map_buffer.data_ptr(),
|
||||
.tensor_map_a_base = tensor_map_a_base,
|
||||
.tensor_map_b_base = tensor_map_b_base,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
.tensor_map_sfb = tensor_map_sfb,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
};
|
||||
const auto& code = SM90FP8Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code);
|
||||
|
||||
SM90FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
|
||||
#include "epilogue.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
@@ -17,6 +19,7 @@ public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
const std::optional<std::string>& epilogue_type;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
@@ -43,7 +46,7 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {}
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -55,7 +58,8 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
|
||||
get_default_epilogue_type(args.epilogue_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -74,7 +78,8 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const std::string& compiled_dims,
|
||||
const std::optional<std::string>& epilogue_type = std::nullopt) {
|
||||
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
|
||||
@@ -98,7 +103,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
@@ -111,6 +116,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = epilogue_type,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -170,6 +176,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -230,6 +237,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
|
||||
151
csrc/jit_kernels/impls/smxx_cublaslt.hpp
Normal file
151
csrc/jit_kernels/impls/smxx_cublaslt.hpp
Normal file
@@ -0,0 +1,151 @@
|
||||
#pragma once
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDADataType.h>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
static auto get_cublaslt_layout(const cudaDataType& type, const int& rows, const int& cols, const int& ld,
|
||||
const std::optional<int>& batch_count = std::nullopt,
|
||||
const std::optional<int>& batch_offset = std::nullopt) {
|
||||
cublasLtMatrixLayout_t layout;
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&layout, type, rows, cols, ld));
|
||||
if (batch_count.has_value()) {
|
||||
DG_HOST_ASSERT(batch_offset.has_value());
|
||||
|
||||
const int64_t batch_offset_int64 = batch_offset.value();
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count.value(), sizeof(batch_count.value())));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_offset_int64, sizeof(batch_offset_int64)));
|
||||
}
|
||||
return layout;
|
||||
}
|
||||
|
||||
static void call_cublaslt_api(const cublasOperation_t& trans_a,
|
||||
const cublasOperation_t& trans_b,
|
||||
const cublasLtMatrixLayout_t& layout_a,
|
||||
const cublasLtMatrixLayout_t& layout_b,
|
||||
const cublasLtMatrixLayout_t& layout_d,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const bool& accumulate) {
|
||||
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
cudaDataType_t scale_type = CUDA_R_32F;
|
||||
const int& math_sms = device_runtime->get_num_sms();
|
||||
bool fp8_fast_accumulate = false;
|
||||
|
||||
// Operation description
|
||||
cublasLtMatmulDesc_t desc;
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescCreate(&desc, compute_type, scale_type));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms)));
|
||||
if (a.scalar_type() == torch::kFloat8_e4m3fn)
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate)));
|
||||
|
||||
// Get cuBLASLt handle, workspace, and stream
|
||||
const auto& handle = device_runtime->get_cublaslt_handle();
|
||||
const auto& workspace = device_runtime->get_cublaslt_workspace();
|
||||
const auto& workspace_bytes = workspace.nbytes();
|
||||
const auto& stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// Algorithm selection
|
||||
cublasLtMatmulPreference_t pref;
|
||||
cublasLtMatmulHeuristicResult_t heuristic;
|
||||
int num_heuristic_results = 0;
|
||||
uint32_t reduction_scheme_mask = CUBLASLT_REDUCTION_SCHEME_NONE | CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE;
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceCreate(&pref));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&workspace_bytes, sizeof(workspace_bytes)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
|
||||
&reduction_scheme_mask, sizeof(reduction_scheme_mask)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, layout_a, layout_b, layout_d, layout_d,
|
||||
pref, 1, &heuristic, &num_heuristic_results));
|
||||
DG_HOST_ASSERT(num_heuristic_results == 1 and "Unable to find any algorithm for the GEMM");
|
||||
|
||||
// Call: D = alpha * (A @ B) + beta * C
|
||||
const float& alpha = 1.0, beta = accumulate ? 1.0 : 0.0;
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmul(handle, // Light handle
|
||||
desc, // Operation description
|
||||
&alpha, // Alpha
|
||||
b.data_ptr(), layout_a, // A
|
||||
a.data_ptr(), layout_b, // B
|
||||
&beta, // Beta
|
||||
d.data_ptr(), layout_d, // C
|
||||
d.data_ptr(), layout_d, // D
|
||||
&heuristic.algo, // Algorithm
|
||||
workspace.data_ptr(), workspace_bytes, // Workspace
|
||||
stream)); // Stream
|
||||
|
||||
// Free memory
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceDestroy(pref));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_a));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_b));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_d));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescDestroy(desc));
|
||||
}
|
||||
|
||||
static void cublaslt_gemm(const torch::Tensor& lhs, const torch::Tensor& rhs,
|
||||
const std::optional<torch::Tensor>& acc,
|
||||
const torch::Tensor& out,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major) {
|
||||
const auto& trans_a = b_major == cute::UMMA::Major::K ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
const auto& trans_b = a_major == cute::UMMA::Major::K ? CUBLAS_OP_N : CUBLAS_OP_T;
|
||||
|
||||
// Duplicate the accumulator if necessary
|
||||
// TODO: remove this
|
||||
if (acc.has_value()) {
|
||||
if (acc->data_ptr() == out.data_ptr()) {
|
||||
DG_HOST_ASSERT(acc->sizes() == out.sizes() and acc->strides() == out.strides());
|
||||
} else {
|
||||
out.copy_(acc.value());
|
||||
}
|
||||
}
|
||||
|
||||
// Matrix layouts
|
||||
const auto& cuda_type_a = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type());
|
||||
const auto& cuda_type_b = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type());
|
||||
const auto& cuda_type_d = at::cuda::ScalarTypeToCudaDataType(out.scalar_type());
|
||||
const auto& layout_a = b_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_a, k, n, rhs.stride(0))
|
||||
: get_cublaslt_layout(cuda_type_a, n, k, rhs.stride(1));
|
||||
const auto& layout_b = a_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_b, k, m, lhs.stride(0))
|
||||
: get_cublaslt_layout(cuda_type_b, m, k, lhs.stride(1));
|
||||
const auto& layout_d = get_cublaslt_layout(cuda_type_d, n, m, out.stride(0));
|
||||
|
||||
call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, acc.has_value());
|
||||
}
|
||||
|
||||
|
||||
static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
|
||||
const int& b, const int& h, const int& r, const int& d) {
|
||||
const auto& m = d, n = b, k = r;
|
||||
const auto& trans_a = CUBLAS_OP_T;
|
||||
const auto& trans_b = CUBLAS_OP_N;
|
||||
|
||||
// Matrix layouts
|
||||
const auto& layout_a = get_cublaslt_layout(CUDA_R_16BF, k, m, rhs.stride(1), h, rhs.stride(0));
|
||||
const auto& layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1));
|
||||
const auto& layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1));
|
||||
|
||||
call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false);
|
||||
}
|
||||
|
||||
|
||||
static void cublaslt_bhd_hdr_bhr(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
|
||||
const int& b, const int& h, const int& r, const int& d) {
|
||||
const auto& m = r, n = b, k = d;
|
||||
const auto& trans_a = CUBLAS_OP_N;
|
||||
const auto& trans_b = CUBLAS_OP_N;
|
||||
|
||||
// Matrix layouts
|
||||
const auto& layout_a = get_cublaslt_layout(CUDA_R_16BF, m, k, rhs.stride(1), h, rhs.stride(0));
|
||||
const auto& layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1));
|
||||
const auto& layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1));
|
||||
|
||||
call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
Reference in New Issue
Block a user