Add more GPU architectures support (#112)

* Add more GPU architectures support

* Update layout.py

* Optimize performance, Add SM90 support, Add 1D2D SM100 support

* Add fmtlib submodule at commit 553ec11

---------

Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
Ray Wang
2025-07-18 11:32:22 +08:00
committed by GitHub
parent 03d0be3d2d
commit 9da4a23561
67 changed files with 5586 additions and 2965 deletions

View File

@@ -0,0 +1,298 @@
#pragma once
#include "../../utils/math.hpp"
namespace deep_gemm {
struct MulticastConfig {
int num_multicast;
bool is_multicast_on_a;
MulticastConfig(const int& num_multicast, const bool& is_multicast_on_a):
num_multicast(num_multicast), is_multicast_on_a(is_multicast_on_a) {
DG_HOST_ASSERT(1 <= num_multicast and num_multicast <= 2);
}
};
struct SharedMemoryConfig {
int smem_size;
int swizzle_a_mode;
int swizzle_b_mode;
int swizzle_cd_mode;
};
struct ThreadConfig {
int num_threads;
// SM90
int num_tma_threads;
int num_math_threads;
// SM100
int num_non_epilogue_threads;
int num_epilogue_threads;
static ThreadConfig sm90(const int& num_tma_threads,
const int& num_math_threads) {
auto config = ThreadConfig();
config.num_threads = num_tma_threads + num_math_threads;
config.num_tma_threads = num_tma_threads;
config.num_math_threads = num_math_threads;
return config;
}
static ThreadConfig sm100(const int& num_non_epilogue_threads,
const int& num_epilogue_threads) {
auto config = ThreadConfig();
config.num_threads = num_non_epilogue_threads + num_epilogue_threads;
config.num_non_epilogue_threads = num_non_epilogue_threads;
config.num_epilogue_threads = num_epilogue_threads;
return config;
}
};
struct GemmConfig {
// Templated configs
GemmType gemm_type;
KernelType kernel_type;
at::ScalarType ab_dtype, cd_dtype;
cute::UMMA::Major major_a;
cute::UMMA::Major major_b;
bool with_accumulation;
int block_m, block_n, block_k;
int num_stages, num_last_stages;
// Runtime configs
int num_sms;
// Structured configs
MulticastConfig multicast_config;
SharedMemoryConfig smem_config;
ThreadConfig thread_config;
};
static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
const int& num_multicast, const int& num_sms,
const bool& require_divisible) {
const bool& divisible = ceil_div(shape_dim, block_dim) % num_multicast == 0 or not require_divisible;
return divisible and num_sms % num_multicast == 0;
}
static int get_swizzle_mode(const int& block_size, const int& elem_size) {
// `> 0` means interleaving
// 16B actually means non-swizzling (but interleaving)
for (const int& mode: {128, 64, 32, 16}) {
if ((block_size * elem_size) % mode == 0)
return mode;
}
DG_HOST_UNREACHABLE("Unreachable");
}
template <typename ArchSpec>
static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const int& num_stages, const MulticastConfig& multicast_config) {
const int& ab_elem_size = static_cast<int>(c10::elementSize(ab_dtype));
const int& cd_elem_size = static_cast<int>(c10::elementSize(cd_dtype));
const int& load_block_m = ArchSpec::get_ab_load_block_m(multicast_config, block_m);
const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n);
const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size);
const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size);
const int& swizzle_cd_mode = get_swizzle_mode(block_n, cd_elem_size);
// Different archs have different epilogue pipelines
const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype);
// A/B shared memory
const int& smem_a_per_stage = load_block_m * block_k * ab_elem_size;
const int& smem_b_per_stage = load_block_n * block_k * ab_elem_size;
// SF shared memory
const auto& [smem_sfa_per_stage, smem_sfb_per_stage] =
ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, ab_dtype, cd_dtype);
const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k);
// M-barriers and tensor memory pointers
const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages);
const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size();
// Sum them up
int smem_size = 0;
smem_size += smem_cd;
smem_size += num_stages * smem_a_per_stage;
smem_size += num_stages * smem_b_per_stage;
smem_size += num_stages * smem_sfa_per_stage;
smem_size += num_stages * smem_sfb_per_stage;
smem_size += smem_extra_sfb;
smem_size += smem_barrier;
smem_size += smem_tmem_ptr;
return SharedMemoryConfig {
.smem_size = smem_size,
.swizzle_a_mode = swizzle_a_mode,
.swizzle_b_mode = swizzle_b_mode,
.swizzle_cd_mode = swizzle_cd_mode,
};
}
template <typename ArchSpec>
static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type,
const int& m, const int& n, const int& k, const int& num_groups,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const bool& with_accumulation, const int& num_sms) {
DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
// Select M/N block sizes
// TODO: support `% 16 == 8` block size on SM90
const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ?
std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256};
std::vector<int> block_ns;
for (int i = 16; i <= 256; i += 16)
block_ns.push_back(i);
// K block size is selected in a fixed manner
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
// Some util functions
const auto& get_num_blocks = [=](const int& block_m, const int& block_n) {
return ceil_div(m, block_m) * ceil_div(n, block_n) * num_groups;
};
const auto& get_num_waves = [=](const int& block_m, const int& block_n) {
return ceil_div(get_num_blocks(block_m, block_n), num_sms);
};
const auto& get_last_wave_util = [=](const int& block_m, const int& block_n) {
const auto& num_last_blocks = get_num_blocks(block_m, block_n) % num_sms;
return num_last_blocks == 0 ? num_sms : num_last_blocks;
};
// Decide block sizes by waves
int best_block_m = 0, best_block_n = 0;
int best_num_waves = 0, best_last_util = 0;
for (const auto& block_m: block_ms) {
for (const auto& block_n: block_ns) {
const int& num_waves = get_num_waves(block_m, block_n);
const auto& last_util = get_last_wave_util(block_m, block_n);
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n))
continue;
bool success = false;
if (best_block_m == 0 or best_block_n == 0 or num_waves < best_num_waves) {
success = true;
} else if (num_waves == best_num_waves) {
// Check last wave utilization
success = last_util > best_last_util;
if (last_util == best_last_util) {
// Case 1: same `block_m`, smaller `block_n` (wasted)
success |= block_m == best_block_m and block_n < best_block_n;
// Case 2: same `block_n`, smaller `block_m` (wasted)
success |= block_n == best_block_n and block_m < best_block_m;
// Case 3: different for both `block_m` and `block_n`, larger `block_n` is better
success |= block_m != best_block_m and block_n > best_block_n;
}
}
// Replace with the new config if successful
if (success) {
best_block_m = block_m, best_block_n = block_n;
best_num_waves = num_waves, best_last_util = last_util;
}
}
}
DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0);
// Decide the number of TMA multicasts and whether broadcast on A
MulticastConfig best_multicast_config = {1, true};
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
gemm_type, m, n, best_block_m, best_block_n, num_sms);
const bool is_legal[2] = {is_legal_on_a, is_legal_on_b};
bool order[2] = {false, true};
if (best_block_m > best_block_n)
std::swap(order[0], order[1]);
for (const bool& is_multicast_on_a: order) {
if (m >= 512 and is_legal[static_cast<int>(is_multicast_on_a)]) {
best_multicast_config = {2, is_multicast_on_a};
break;
}
}
// Always pick the largest number of stage
constexpr int smem_capacity = ArchSpec::smem_capacity;
int best_num_stages = 0;
SharedMemoryConfig best_smem_config;
for (int num_stages = std::min(12, ceil_div(k, block_k)); num_stages > 0; -- num_stages) {
if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
continue;
best_smem_config = get_smem_config<ArchSpec>(kernel_type,
m, n, k,
best_block_m, best_block_n, block_k,
major_a, major_b,
ab_dtype, cd_dtype,
num_stages, best_multicast_config);
if (best_smem_config.smem_size <= smem_capacity) {
best_num_stages = num_stages;
break;
}
}
DG_HOST_ASSERT(best_num_stages != 0);
// Recompute the minimal number of SMs required
// NOTES: less L2 cache usage and less GPU frequency drop
int num_min_sms = num_sms;
if (ArchSpec::should_minimize_num_sms()) {
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves);
num_min_sms = align(num_min_sms, best_multicast_config.num_multicast);
DG_HOST_ASSERT(num_min_sms <= num_sms);
}
const auto& config = GemmConfig {
.gemm_type = gemm_type,
.kernel_type = kernel_type,
.ab_dtype = ab_dtype,
.cd_dtype = cd_dtype,
.major_a = major_a,
.major_b = major_b,
.with_accumulation = with_accumulation,
.block_m = best_block_m,
.block_n = best_block_n,
.block_k = block_k,
.num_stages = best_num_stages,
.num_last_stages = ceil_div(k, block_k) % best_num_stages,
.num_sms = num_min_sms,
.multicast_config = best_multicast_config,
// ReSharper disable once CppLocalVariableMightNotBeInitialized
.smem_config = best_smem_config,
.thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n)
};
// Print configs for the first time
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS")) {
auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b,
ab_dtype, cd_dtype, with_accumulation, num_sms);
static std::set<decltype(key)> printed;
if (not printed.contains(key)) {
printf("Gemm type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, "
"A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: %d, "
"SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, "
"SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, "
"swizzle B: %d, swizzle CD: %d, threads: %d\n",
static_cast<int>(gemm_type), static_cast<int>(kernel_type), m, n, k, num_groups,
static_cast<int>(major_a), static_cast<int>(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype),
static_cast<int>(with_accumulation), num_sms, best_block_m, best_block_n, block_k,
best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast,
static_cast<int>(best_multicast_config.is_multicast_on_a),
best_smem_config.smem_size, best_smem_config.swizzle_a_mode, best_smem_config.swizzle_b_mode,
best_smem_config.swizzle_cd_mode, config.thread_config.num_threads);
printed.insert(key);
}
}
return config;
}
} // namespace deep_gemm

View File

@@ -0,0 +1,144 @@
#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
// Reuse some types in the JIT modules
#include <deep_gemm/common/types.hpp>
#include "common.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
struct SM100ArchSpec {
static constexpr int smem_capacity = 232448;
static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) {
return block_m / (config.is_multicast_on_a ? config.num_multicast : 1);
}
static int get_ab_load_block_n(const MulticastConfig& config, const int& block_n) {
return block_n / (config.is_multicast_on_a ? 1 : config.num_multicast);
}
static int get_cd_store_block_m(const int& block_m) {
constexpr int layout_ad_m = 128;
return std::min(block_m, layout_ad_m);
}
static int get_cd_store_block_n(const int& block_n) {
return block_n;
}
static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) {
constexpr int num_utccp_aligned_elems = 128;
DG_HOST_ASSERT(block_m % num_utccp_aligned_elems == 0);
switch (ab_dtype) {
case torch::kBFloat16: return {0, 0};
case torch::kFloat8_e4m3fn: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)};
default: DG_HOST_UNREACHABLE("Unknown dtype");
}
}
static bool is_block_size_legal(const KernelType& kernel_type,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const int& block_m, const int& block_n) {
// Layout A/D does not support `block_m == 64` and `block_n % 16 != 0`
if (block_m == 64 or block_n % 16 != 0)
return false;
// Performance is lower with 1D1D and `block_m == 256`
if (kernel_type == KernelType::Kernel1D1D and major_b == cute::UMMA::Major::K and block_m != 128)
return false;
// 1D2D kernels' maximum block N is 128
// 1D2D kernels require more friendly block Ns
if (kernel_type == KernelType::Kernel1D2D and (block_n > 128 or 128 % block_n != 0))
return false;
// Check tensor memory validity
int sf_block_m = 0, sf_block_n = 0;
if (kernel_type == KernelType::Kernel1D1D) {
const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype);
sf_block_m = sf_block_m_, sf_block_n = sf_block_n_;
}
if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512)
return false;
// NOTES: when B is MN-major, we restrict `block_n` to multiples of 64,
// since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA
return major_b == cute::UMMA::Major::K or block_n % 64 == 0;
}
static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const int& num_stages,
const int& block_m, const int& block_n, const int& block_k) {
return true;
}
static bool should_minimize_num_sms() {
return false;
}
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
const int& m, const int& n, const int& block_m, const int& block_n,
const int& num_sms) {
// TODO: support other layouts
return {
is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
false,
};
}
static ThreadConfig get_thread_config(const KernelType& kernel_type,
const int& block_m, const int& block_n) {
return ThreadConfig::sm100(128, kernel_type == KernelType::Kernel1D1D ? 128 : block_m);
}
static int get_smem_cd_size(const KernelType& kernel_type,
const int& block_m, const int& block_n,
const int& swizzle_cd_mode,
const at::ScalarType& cd_dtype) {
constexpr static int layout_ad_m = 128;
return (kernel_type == KernelType::Kernel1D1D ? std::min(block_m, layout_ad_m) : block_m) * swizzle_cd_mode * 2;
}
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
const int& block_m, const int& block_n, const int& block_k,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) {
if (ab_dtype == torch::kBFloat16)
return {0, 0};
int smem_sfa_per_stage = 0;
int smem_sfb_per_stage = 0;
if (kernel_type == KernelType::Kernel1D1D) {
const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype);
smem_sfa_per_stage = sf_block_m * 4;
smem_sfb_per_stage = sf_block_n * 4;
} else {
smem_sfa_per_stage = block_m * 4;
smem_sfb_per_stage = 0;
}
return {smem_sfa_per_stage, smem_sfb_per_stage};
}
static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k) {
return 0;
}
static int get_barrier_smem_size(const int& num_stages) {
// TODO: remove SF barriers for BF16 GEMMs
// TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers
// NOTES: 1D2D kernel will not use the with-SF full barriers
// NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
return num_stages * 8 * 3 + 2 * 8 * 2;
}
static int get_tmem_ptr_smem_size() {
return 4;
}
};
} // namespace deep_gemm

View File

@@ -0,0 +1,115 @@
#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
// Reuse some types in the JIT modules
#include <deep_gemm/common/types.hpp>
#include "common.hpp"
namespace deep_gemm {
struct SM90ArchSpec {
static constexpr int smem_capacity = 232448;
static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) {
return block_m;
}
static int get_ab_load_block_n(const MulticastConfig& multicast_config, const int& block_n) {
return block_n;
}
static int get_cd_store_block_m(const int& block_m) {
return block_m;
}
static int get_cd_store_block_n(const int& block_n) {
return block_n;
}
static bool is_block_size_legal(const KernelType& kernel_type,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const int& block_m, const int& block_n) {
// FP32 output does not support `block_m == 256`
if (cd_dtype == at::kFloat and block_m == 256)
return false;
// Must be some fixed block N selections
if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 or block_n != 152))
return false;
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 or block_n != 160))
return false;
// Avoid bank conflicts for FP32 output
if (cd_dtype == torch::kFloat and block_n % 16 == 0)
return false;
// The block sizes cannot be too large (for enough registers), so at least one dim less than 128
return block_m <= 128 or block_n <= 128;
}
static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const int& num_stages,
const int& block_m, const int& block_n, const int& block_k) {
// Unrolling both stages and `num_former_iters` will cause large code size
if (ab_dtype == torch::kFloat8_e4m3fn and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4)
return num_stages <= 4;
return true;
}
static bool should_minimize_num_sms() {
return true;
}
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
const int& m, const int& n, const int& block_m, const int& block_n,
const int& num_sms) {
return {
is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
is_multicast_legal(m, block_m, 2, num_sms, false) and gemm_type != GemmType::MGroupedMasked,
};
}
static ThreadConfig get_thread_config(const KernelType& kernel_type,
const int& block_m, const int& block_n) {
return ThreadConfig::sm90(128, (block_m == 64 ? 1 : 2) * 128);
}
static int get_smem_cd_size(const KernelType& kernel_type,
const int& block_m, const int& block_n,
const int& swizzle_cd_mode, const at::ScalarType& cd_dtype) {
return block_m * block_n * static_cast<int>(c10::elementSize(cd_dtype));
}
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
const int& block_m, const int& block_n, const int& block_k,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) {
if (ab_dtype == torch::kBFloat16)
return {0, 0};
int smem_sfa_per_stage = block_m * static_cast<int>(sizeof(float));
int smem_sfb_per_stage = 0;
// TODO: figure out here
if (kernel_type == KernelType::Kernel1D1D)
smem_sfb_per_stage = align(block_n * 4, block_k);
return {smem_sfa_per_stage, smem_sfb_per_stage};
}
static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k) {
const auto& use_uniform_sfb = block_k % block_n == 0 ? 1 : 2;
return align<int>(ceil_div(k, block_k) * static_cast<int>(sizeof(float)) * use_uniform_sfb, 8);
}
static int get_barrier_smem_size(const int& num_stages) {
// For 1D1D kernels, there is an extra barrier for accumulation
return (num_stages + 1) * 8 * 2;
}
static int get_tmem_ptr_smem_size() {
return 0;
}
};
} // namespace deep_gemm

View File

@@ -0,0 +1,173 @@
#pragma once
#include <cuda.h>
#include <torch/python.h>
#include "../../utils/math.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
static std::pair<int, int> get_inner_outer_dims(const cute::UMMA::Major& major, const int& k, const int& mn) {
return major == cute::UMMA::Major::K ? std::make_pair(k, mn) : std::make_pair(mn, k);
}
static int get_non_contiguous_dim(const cute::UMMA::Major& major) {
return major == cute::UMMA::Major::K ? -2 : -1;
}
static int get_compiled_dim(const int& dim, const char& name, const std::string& compiled_dims) {
for (const char& c: compiled_dims) {
if (name == c)
return dim;
}
return 0;
}
static std::string to_string(const cute::UMMA::Major& major) {
switch (major) {
case cute::UMMA::Major::K: return "cute::UMMA::Major::K";
case cute::UMMA::Major::MN: return "cute::UMMA::Major::MN";
}
DG_HOST_UNREACHABLE("Unknown major");
}
static std::string to_string(const GemmType& type) {
switch (type) {
case GemmType::Normal: return "GemmType::Normal";
case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous";
case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked";
case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous";
}
DG_HOST_UNREACHABLE("Unknown GEMM type");
}
static std::string to_string(const at::ScalarType& dtype) {
switch (dtype) {
case torch::kInt: return "int";
case torch::kFloat: return "float";
case torch::kBFloat16: return "cutlass::bfloat16_t";
default: DG_HOST_UNREACHABLE("Unsupported dtype");
}
}
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype) {
switch (dtype) {
case torch::kInt: return CU_TENSOR_MAP_DATA_TYPE_INT32;
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
default: DG_HOST_UNREACHABLE("Unsupported dtype");
}
}
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode) {
switch (mode) {
case 0: return CU_TENSOR_MAP_SWIZZLE_NONE;
case 16: return CU_TENSOR_MAP_SWIZZLE_NONE;
case 32: return CU_TENSOR_MAP_SWIZZLE_32B;
case 64: return CU_TENSOR_MAP_SWIZZLE_64B;
case 128: return CU_TENSOR_MAP_SWIZZLE_128B;
default: DG_HOST_UNREACHABLE("Unsupported swizzling mode");
}
}
static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
int gmem_inner_dim, int gmem_outer_dim,
int smem_inner_dim, int smem_outer_dim,
const int& gmem_outer_stride,
const int& swizzle_mode) {
const auto& elem_size = static_cast<int>(t.element_size());
if (swizzle_mode != 0)
smem_inner_dim = swizzle_mode / elem_size;
CUtensorMap tensor_map;
const cuuint64_t gmem_dims[2] = {static_cast<cuuint64_t>(gmem_inner_dim), static_cast<cuuint64_t>(gmem_outer_dim)};
const cuuint32_t smem_dims[2] = {static_cast<cuuint32_t>(smem_inner_dim), static_cast<cuuint32_t>(smem_outer_dim)};
const cuuint64_t gmem_strides[1] = {static_cast<cuuint64_t>(gmem_outer_stride * elem_size), };
const cuuint32_t elem_strides[2] = {1, 1};
if (get_env<int>("DG_JIT_DEBUG")) {
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d, elem size: %d\n",
gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim,
gmem_outer_stride, swizzle_mode, elem_size);
}
DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type()),
2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode),
CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
return tensor_map;
}
static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
const torch::Tensor& t,
const int& shape_m, const int& shape_k,
const int& block_m, const int& block_k,
const int& outer_stride,
const int& num_groups,
const int& swizzle_mode) {
if (num_groups > 1)
DG_HOST_ASSERT(major == cute::UMMA::Major::K);
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_m);
return make_tma_2d_desc(t,
gmem_inner_dim, gmem_outer_dim,
smem_inner_dim, smem_outer_dim,
outer_stride,
swizzle_mode);
}
static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
const torch::Tensor& t,
const int& shape_n, const int& shape_k,
const int& block_n, const int& block_k,
const int& outer_stride,
const int& num_groups,
const int& swizzle_mode) {
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n);
const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n);
// `num_groups` is always applied into the outer dimensions
return make_tma_2d_desc(t,
gmem_inner_dim, gmem_outer_dim * num_groups,
smem_inner_dim, smem_outer_dim,
outer_stride,
swizzle_mode);
}
static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
const int& shape_m, const int& shape_n,
const int& block_m, const int& block_n,
const int& outer_stride,
const int& num_groups,
const int& swizzle_mode) {
// Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode`
// bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
return make_tma_2d_desc(t,
shape_n, shape_m * num_groups,
block_n, block_m,
outer_stride,
swizzle_mode);
}
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
const torch::Tensor& t,
int shape_mn, int shape_k,
const int& block_mn, const int& block_k,
const int& num_groups,
const int& swizzle_mode) {
DG_HOST_ASSERT(major == cute::UMMA::Major::MN);
// TODO: maybe swizzle SF as well
DG_HOST_ASSERT(swizzle_mode == 0);
shape_mn = get_tma_aligned_size(shape_mn, static_cast<int>(t.element_size()));
return make_tma_2d_desc(t,
shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups,
block_mn, 1,
shape_mn,
swizzle_mode);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,351 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8Gemm1D1DRuntime> {
public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
GemmConfig gemm_config;
LaunchArgs launch_args;
void* grouped_layout;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_sfa;
CUtensorMap tensor_map_sfb;
CUtensorMap tensor_map_c;
CUtensorMap tensor_map_d;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d1d_impl<
{}, {},
{}, {}, {},
{}, {}, {},
{},
{}, {}, {},
{}, {},
{}, {},
{}, {},
{},
{}, {}
>);
}};
)",
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.num_groups,
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
to_string(args.gemm_config.gemm_type),
args.gemm_config.with_accumulation,
to_string(args.gemm_config.cd_dtype));
}
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
args.grouped_layout, args.m, args.n, args.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_sfa, args.tensor_map_sfb,
args.tensor_map_c, args.tensor_map_d));
}
};
static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Normal, KernelType::Kernel1D1D,
m, n, k, 1, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
const auto& cd = c.value_or(d);
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_c = make_tma_cd_desc(cd, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(cd.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, 1, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.block_n, config.block_k, 1, 0);
// Duplicate the accumulator if necessary
if (c.has_value()) {
if (c->data_ptr() == d.data_ptr()) {
DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
} else {
// ReSharper disable once CppExpressionWithoutSideEffects
d.copy_(c.value());
}
}
// Launch
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_c = tensor_map_c,
.tensor_map_d = tensor_map_d
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
}
static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& m_indices,
const int& num_groups, const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::MGroupedContiguous, KernelType::Kernel1D1D,
m, n, k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
device_runtime->get_num_sms());
// Create tensor descriptors
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, 1, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.block_n, config.block_k, num_groups, 0);
// Launch kernel
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.grouped_layout = m_indices.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_c = tensor_map_d,
.tensor_map_d = tensor_map_d
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
}
static void sm100_fp8_m_grouped_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& num_groups, const int& m, const int& n, const int& k,
const int& expected_m,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::MGroupedMasked, KernelType::Kernel1D1D,
expected_m, n, k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
device_runtime->get_num_sms());
// Create tensor descriptors
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), num_groups,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, num_groups, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.block_n, config.block_k, num_groups, 0);
// Launch kernel
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.grouped_layout = masked_m.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_c = tensor_map_d,
.tensor_map_d = tensor_map_d
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
}
static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n,
const std::vector<int>& ks, const torch::Tensor& ks_tensor,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);
int sum_k = 0, sum_sf_k = 0;
for (const auto& k: ks) {
sum_k += k, sum_sf_k += ceil_div(k, 512);
DG_HOST_ASSERT(k % 128 == 0);
}
const auto& num_groups = static_cast<int>(ks.size());
// Get config using max K for better performance
const auto& max_k = *std::ranges::max_element(ks);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Create tensor descriptors
const auto& cd = c.value_or(d);
const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(0)), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(0)), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(1)), num_groups,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_c = make_tma_cd_desc(cd, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(cd.stride(1)), num_groups,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 512,
config.block_m, config.block_k, num_groups, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 512,
config.block_n, config.block_k, num_groups, 0);
// Duplicate the accumulator if necessary
if (c.has_value()) {
DG_HOST_ASSERT(c->data_ptr() == d.data_ptr());
DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
}
// Launch kernel
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = sum_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.grouped_layout = ks_tensor.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_c = tensor_map_c,
.tensor_map_d = tensor_map_d
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,242 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
class SM100FP8Gemm1D2DRuntime final: public LaunchRuntime<SM100FP8Gemm1D2DRuntime> {
public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
GemmConfig gemm_config;
LaunchArgs launch_args;
void *sfb, *grouped_layout;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
CUtensorMap tensor_map_sfa;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d2d_impl<
{}, {},
{}, {}, {},
{}, {}, {},
{},
{}, {}, {},
{}, {},
{}, {},
{}, {},
{}, {}
>);
}};
)",
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.num_groups,
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
to_string(args.gemm_config.gemm_type),
to_string(args.gemm_config.cd_dtype));
}
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
args.sfb, args.grouped_layout,
args.m, args.n, args.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_d, args.tensor_map_sfa));
}
};
static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(not c.has_value());
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Normal, KernelType::Kernel1D2D,
m, n, k, 1, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, 1, 0);
// Launch
const SM100FP8Gemm1D2DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
.tensor_map_sfa = tensor_map_sfa,
};
const auto& code = SM100FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_gemm_1d2d", code);
SM100FP8Gemm1D2DRuntime::launch(runtime, args);
}
static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& m_indices,
const int& num_groups, const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
m, n, k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
device_runtime->get_num_sms());
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, 1, 0);
// Launch
const SM100FP8Gemm1D2DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = m_indices.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
.tensor_map_sfa = tensor_map_sfa,
};
const auto& code = SM100FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d2d", code);
SM100FP8Gemm1D2DRuntime::launch(runtime, args);
}
static void sm100_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& num_groups, const int& m, const int& n, const int& k,
const int& expected_m,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::MGroupedMasked, KernelType::Kernel1D2D,
expected_m, n, k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
device_runtime->get_num_sms());
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), num_groups,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, num_groups, 0);
// Launch
const SM100FP8Gemm1D2DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = masked_m.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
.tensor_map_sfa = tensor_map_sfa,
};
const auto& code = SM100FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d2d", code);
SM100FP8Gemm1D2DRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,255 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime> {
public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
GemmConfig gemm_config;
LaunchArgs launch_args;
void *sfb, *grouped_layout;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
CUtensorMap tensor_map_sfa;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d2d_impl<
{}, {}, {},
{},
{}, {}, {},
{},
{}, {},
{}, {},
{}, {},
{}
>);
}};
)",
// TODO: add CD dtype
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.num_groups,
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.gemm_config.smem_config.swizzle_cd_mode,
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
to_string(args.gemm_config.gemm_type));
}
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
args.sfb, args.grouped_layout,
args.m, args.n, args.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_d, args.tensor_map_sfa));
}
};
static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Normal, KernelType::Kernel1D2D,
m, n, k, 1, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, 1, 0);
// Launch
const SM90FP8Gemm1D2DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
.tensor_map_sfa = tensor_map_sfa,
};
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code);
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}
static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& m_indices,
const int& num_groups, const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
m, n, k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
device_runtime->get_num_sms());
// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, 1, 0);
// Launch
const SM90FP8Gemm1D2DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = m_indices.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
.tensor_map_sfa = tensor_map_sfa,
};
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code);
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}
static void sm90_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& num_groups, const int& m, const int& n, const int& k,
const int& expected_m,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& aligned_k = align(k, 128);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::MGroupedMasked, KernelType::Kernel1D2D,
expected_m, n, k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
device_runtime->get_num_sms());
// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), num_groups,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, num_groups, 0);
// Launch
const SM90FP8Gemm1D2DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = masked_m.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
.tensor_map_sfa = tensor_map_sfa,
};
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,199 @@
#pragma once
#include <torch/python.h>
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../../utils/layout.hpp"
namespace deep_gemm {
class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime<TransposeAndPackFP32IntoUE8M0Runtime> {
public:
struct Args {
int mn, sf_k;
int block_mn;
void *sf, *out;
LaunchArgs launch_args;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/smxx_layout.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&transpose_and_pack_fp32_into_ue8m0<
{}, {}, {}
>);
}};
)", args.launch_args.num_threads, args.block_mn, args.sf_k);
}
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, args.sf, args.out, static_cast<uint32_t>(args.mn)));
}
};
class PackFP32IntoUE8M0Runtime final: public LaunchRuntime<PackFP32IntoUE8M0Runtime> {
public:
struct Args {
int num_groups, mn, sf_k, packed_sf_k;
int block_mn, block_packed_sf_k;
void *sf, *out, *ks;
LaunchArgs launch_args;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/smxx_layout.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&pack_fp32_into_ue8m0<
{}, {}, {}, {}
>);
}};
)", args.num_groups, args.launch_args.num_threads, args.block_mn, args.block_packed_sf_k);
}
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k));
}
};
static std::tuple<int, int, int, int, int, torch::Tensor> preprocess_sf(const torch::Tensor& sf) {
// NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
const auto& dim = sf.dim();
DG_HOST_ASSERT(dim == 2 or dim == 3);
DG_HOST_ASSERT(sf.scalar_type() == torch::kFloat);
const auto& batched_sf = dim == 2 ? sf.unsqueeze(0) : sf;
const auto& [num_groups, mn, sf_k] = get_shape<3>(batched_sf);
const auto& tma_aligned_mn = get_tma_aligned_size(mn, static_cast<int>(sf.element_size()));
return {dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf};
}
static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
// The last kernel already gives a column-major TMA aligned layout
if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn)
return (dim == 2) ? batched_sf.squeeze(0) : batched_sf;
// Normal layout requires transposing
auto aligned_sf = torch::empty_strided({num_groups, tma_aligned_mn, sf_k}, {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, batched_sf.options());
aligned_sf = aligned_sf.slice(1, 0, mn).copy_(batched_sf);
return (dim == 2) ? aligned_sf.squeeze(0) : aligned_sf;
}
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) {
const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
const auto& packed_sf_k = ceil_div(sf_k, 4);
const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k},
{packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn},
at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
DG_HOST_ASSERT(num_groups == 1 or (mn * sf_k) % 4 == 0);
// Launch the kernel
if (batched_sf.is_contiguous()) {
constexpr int block_mn = 48;
constexpr int num_threads = 512;
const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = {
.mn = mn,
.sf_k = sf_k,
.block_mn = block_mn,
.sf = batched_sf.data_ptr(),
.out = out.data_ptr(),
.launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4)
};
const auto& code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args);
const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code);
TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args);
} else {
DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1);
DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn);
constexpr int block_mn = 128;
constexpr int block_packed_sf_k = 16;
constexpr int num_threads = 512;
const PackFP32IntoUE8M0Runtime::Args& args = {
.num_groups = 1,
.mn = mn,
.sf_k = sf_k,
.packed_sf_k = packed_sf_k,
.block_mn = block_mn,
.block_packed_sf_k = block_packed_sf_k,
.sf = batched_sf.data_ptr(),
.out = out.data_ptr(),
.ks = nullptr,
.launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
};
const auto& code = PackFP32IntoUE8M0Runtime::generate(args);
const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code);
PackFP32IntoUE8M0Runtime::launch(runtime, args);
}
return (dim == 2) ? out.squeeze(0) : out;
}
static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf,
const torch::Tensor& ks_tensor,
const std::vector<int>& ks) {
const auto& [sf_k, mn] = get_shape<2>(sf);
const auto& num_groups = static_cast<int>(ks.size());
int ref_sf_k = 0, packed_sf_k = 0;
for (const auto& k: ks)
ref_sf_k += ceil_div(k, 128), packed_sf_k += ceil_div(k, 512);
DG_HOST_ASSERT(sf.is_contiguous());
DG_HOST_ASSERT(ref_sf_k == sf_k);
DG_HOST_ASSERT(num_groups <= 128 and mn % 4 == 0);
const auto& out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt));
constexpr int block_mn = 128;
constexpr int block_packed_sf_k = 16;
constexpr int num_threads = 512;
const PackFP32IntoUE8M0Runtime::Args& args = {
.num_groups = num_groups,
.mn = mn,
.sf_k = sf_k,
.packed_sf_k = packed_sf_k,
.block_mn = block_mn,
.block_packed_sf_k = block_packed_sf_k,
.sf = sf.data_ptr(),
.out = out.data_ptr(),
.ks = ks_tensor.data_ptr(),
.launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
};
const auto& code = PackFP32IntoUE8M0Runtime::generate(args);
const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code);
PackFP32IntoUE8M0Runtime::launch(runtime, args);
return out;
}
} // namespace deep_gemm