Make various updates and fixes: (#164)
- Add BF16 support for SM90 and SM100 - Refactor Python APIs - Other fixes and code refactoring
This commit is contained in:
@@ -11,16 +11,22 @@ enum class KGroupedIndexType {
|
||||
SF_K,
|
||||
};
|
||||
|
||||
template <uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumSMs, bool isMulticastOnA>
|
||||
template <GemmType kGemmType, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumSMs, bool kIsMulticastOnA>
|
||||
static constexpr uint32_t get_num_1d_blocks_per_group() {
|
||||
// Select the best from candidates
|
||||
uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits<uint32_t>::max();
|
||||
for (const auto& candidate: {8u, 16u}) {
|
||||
const auto& usage = isMulticastOnA ?
|
||||
candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
|
||||
candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
|
||||
if (usage < min_usage)
|
||||
min_usage = usage, num_best_blocks = candidate;
|
||||
if constexpr (kGemmType == GemmType::MGroupedContiguous or
|
||||
kGemmType == GemmType::MGroupedMasked) {
|
||||
// For grouped GEMMs, let weights always stay in the L2 cache and read activations by once
|
||||
num_best_blocks = kNumSMs;
|
||||
} else {
|
||||
for (const auto& candidate: {8u, 16u}) {
|
||||
const auto& usage = kIsMulticastOnA ?
|
||||
candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
|
||||
candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
|
||||
if (usage < min_usage)
|
||||
min_usage = usage, num_best_blocks = candidate;
|
||||
}
|
||||
}
|
||||
return num_best_blocks;
|
||||
}
|
||||
@@ -32,7 +38,7 @@ template <GemmType kGemmType,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
uint32_t kNumSMs,
|
||||
uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
|
||||
uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<kGemmType, BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
|
||||
struct Scheduler {
|
||||
int current_iter = -1;
|
||||
|
||||
|
||||
@@ -48,7 +48,18 @@ struct FP8MMASelector {
|
||||
if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN();
|
||||
}
|
||||
|
||||
static constexpr auto select_type() {
|
||||
@@ -58,6 +69,71 @@ struct FP8MMASelector {
|
||||
using type = decltype(select_type());
|
||||
};
|
||||
|
||||
template <int N_, typename MMA>
|
||||
struct BF16MMA {
|
||||
|
||||
template <size_t ...Idx>
|
||||
__forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
|
||||
using namespace cute::SM90::GMMA;
|
||||
MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = N_;
|
||||
static constexpr int K = 16;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
template <int N>
|
||||
struct BF16MMASelector {
|
||||
|
||||
static constexpr auto select_mma() {
|
||||
using namespace cute::SM90::GMMA;
|
||||
if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
}
|
||||
|
||||
static constexpr auto select_type() {
|
||||
return BF16MMA<N, decltype(select_mma())>();
|
||||
}
|
||||
|
||||
using type = decltype(select_type());
|
||||
};
|
||||
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x2_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
|
||||
@@ -144,4 +144,22 @@ __device__ __forceinline__ void prefetch_l1(void *ptr) {
|
||||
asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
|
||||
}
|
||||
|
||||
template <uint32_t kNumBytes>
|
||||
struct Vectorized {
|
||||
static auto zeros() {
|
||||
// TODO: add `ulonglong4` for SM100 once `__ldg` support this
|
||||
if constexpr (kNumBytes > 0 and kNumBytes % 16 == 0) {
|
||||
return make_uint4(0, 0, 0, 0);
|
||||
} else if constexpr (kNumBytes > 0 and kNumBytes % 8 == 0) {
|
||||
return make_uint2(0, 0);
|
||||
} else if constexpr (kNumBytes > 0 and kNumBytes % 4 == 0) {
|
||||
return 0;
|
||||
} else {
|
||||
DG_STATIC_ASSERT(kNumBytes > 0 and kNumBytes % 4 == 0, "Invalid vectorization");
|
||||
}
|
||||
}
|
||||
|
||||
using vec_t = decltype(zeros());
|
||||
};
|
||||
|
||||
} // namespace `deep_gemm`
|
||||
|
||||
@@ -1,3 +1,498 @@
|
||||
#pragma once
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
|
||||
// TODO: add implement
|
||||
#include <cutlass/arch/barrier.h>
|
||||
|
||||
#include <deep_gemm/common/scheduler.cuh>
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm100_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm100;
|
||||
|
||||
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
||||
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
uint32_t kNumSMs,
|
||||
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
|
||||
uint64_t kTensorCoreUtilControl>
|
||||
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
||||
sm100_bf16_gemm_impl(int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_c,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// GEMM with accumulation must have FP32 output
|
||||
if constexpr (kWithAccumulation)
|
||||
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t LAYOUT_AD_M = 128;
|
||||
constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M;
|
||||
constexpr uint32_t kNumTMAStoreStages = 2;
|
||||
DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K");
|
||||
DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
|
||||
|
||||
// Overwrite shape constants if the compiler gives
|
||||
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
||||
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
||||
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
||||
|
||||
// Utils
|
||||
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
|
||||
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
const auto lane_idx = get_lane_idx();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
|
||||
// 2-CTA MMA
|
||||
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
|
||||
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
|
||||
constexpr uint32_t STORE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
||||
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
|
||||
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
|
||||
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D");
|
||||
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
|
||||
|
||||
// Share memory sizes
|
||||
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
|
||||
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
|
||||
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
|
||||
|
||||
// Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
|
||||
// TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2`
|
||||
constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2;
|
||||
|
||||
// Real tensor memory size and offsets
|
||||
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
|
||||
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == 0) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
if constexpr (kWithAccumulation)
|
||||
cute::prefetch_tma_descriptor(&tensor_map_c);
|
||||
}
|
||||
|
||||
// Data on shared memory (layout as ordered below)
|
||||
cd_dtype_t* smem_cd[kNumTMAStoreStages];
|
||||
cutlass::bfloat16_t* smem_a[kNumStages];
|
||||
cutlass::bfloat16_t* smem_b[kNumStages];
|
||||
|
||||
// Fill D/A/B pointers
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i)
|
||||
smem_cd[i] = reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Fill barriers
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
|
||||
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
||||
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
||||
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
||||
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
||||
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
|
||||
|
||||
// Fill the tensor memory pointer
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
|
||||
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
||||
|
||||
// Initialize barriers
|
||||
if (threadIdx.x == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
// Arrive only at the leader CTA
|
||||
full_barriers[i]->init(kNumMulticast);
|
||||
// Arrive at all CTAs
|
||||
empty_barriers[i]->init(1);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
|
||||
// Arrive at all CTAs
|
||||
tmem_full_barriers[i]->init(1);
|
||||
// Arrive only at the leader CTA
|
||||
tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::fence_barrier_init();
|
||||
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
|
||||
// Allocate tensor memory
|
||||
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
uint32_t phase = 0;
|
||||
auto launch_k_iterations = [&](const auto& func) {
|
||||
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
|
||||
const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K);
|
||||
const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages;
|
||||
|
||||
// TODO: refactor here
|
||||
if (num_last_stages == 0) {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1)
|
||||
func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages);
|
||||
} else {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1)
|
||||
func(k_iter, DivisibleK{}, false, num_last_stages);
|
||||
func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1;
|
||||
}
|
||||
};
|
||||
|
||||
auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) {
|
||||
DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2,
|
||||
"Too many epilogue stages, please modify the Python heuristic as well");
|
||||
accum_stage_idx == 0 ? func(0) : func(1);
|
||||
};
|
||||
|
||||
// Dispatch warps into different roles
|
||||
if (warp_idx == 0) {
|
||||
// TMA load warp
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait(phase ^ 1);
|
||||
|
||||
// Compute offsets
|
||||
// NOTES: the group is always concatenated with the outer dimension
|
||||
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
|
||||
shape_m, BLOCK_M, m_block_idx);
|
||||
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
|
||||
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
||||
|
||||
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
||||
// And for all m-grouped GEMMs, A must be K-majored
|
||||
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
|
||||
uint32_t k_block_idx = k_iter * kNumStages + s;
|
||||
uint32_t k_idx = k_block_idx * BLOCK_K;
|
||||
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
|
||||
// Add 2 CTA offsets
|
||||
if constexpr (kNumMulticast > 1) {
|
||||
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
|
||||
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
||||
}
|
||||
|
||||
// Issue TMAs
|
||||
if (cute::elect_one_sync()) {
|
||||
if constexpr (kMajorA == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx);
|
||||
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx);
|
||||
}
|
||||
// Arrive at full barriers
|
||||
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
||||
if (is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
|
||||
if (not is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive(0u);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait(phase ^ 1);
|
||||
if (is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive();
|
||||
if (not is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive(0u);
|
||||
}
|
||||
});
|
||||
}
|
||||
} else if (warp_idx == 1 and is_leader_cta) {
|
||||
// MMA issue warp
|
||||
// NOTES: only the leader CTA will do this
|
||||
// Make instruction descriptor
|
||||
// TODO: refactor `UMMA_M` calculation
|
||||
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
|
||||
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
|
||||
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t);
|
||||
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>();
|
||||
|
||||
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
||||
auto a_desc = make_umma_desc<kMajorA, BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
|
||||
auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
||||
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
||||
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
||||
|
||||
// Checks for MMA instructions
|
||||
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
|
||||
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
||||
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
|
||||
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
|
||||
"Invalid MMA instruction shape");
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
|
||||
// Wait tensor memory empty barrier arrival
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) {
|
||||
auto umma_arrive = [](const uint64_t* barrier) {
|
||||
if constexpr (kNumMulticast == 1) {
|
||||
cutlass::arch::umma_arrive(barrier);
|
||||
} else {
|
||||
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
|
||||
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
|
||||
}
|
||||
};
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[s]));
|
||||
|
||||
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
||||
if (do_tmem_full_arrive)
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait TMA arrival
|
||||
full_barriers[s]->wait(phase);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Let tensor cores relax for lower possibility of frequency drop
|
||||
DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control");
|
||||
if constexpr (kTensorCoreUtilControl < 100) {
|
||||
constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull;
|
||||
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
|
||||
const auto& start_clock = clock64();
|
||||
if (cute::elect_one_sync())
|
||||
while (clock64() - start_clock < kNumDummyCycles) {}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Issue UMMA in the leader CTA
|
||||
using cute_mma_t = cute::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_MMA_F16BF16_SS <cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>,
|
||||
cute::SM100_MMA_F16BF16_2x1SM_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>>;
|
||||
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
||||
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s);
|
||||
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s);
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
||||
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
|
||||
cute_mma_t::fma(a_desc, b_desc,
|
||||
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
|
||||
k_iter > 0 or s > 0 or k > 0,
|
||||
runtime_instr_desc);
|
||||
}
|
||||
}
|
||||
|
||||
// Commit to the mbarrier object
|
||||
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
||||
empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait(phase);
|
||||
empty_barrier_arrive(s, false);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
} else if (warp_idx >= kNumNonEpilogueThreads / 32) {
|
||||
// Epilogue warp groups
|
||||
const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads;
|
||||
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
|
||||
|
||||
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
||||
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
|
||||
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
||||
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
||||
|
||||
// TMA checks
|
||||
constexpr uint32_t kNumBankGroupBytes = 16;
|
||||
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
|
||||
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
|
||||
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
|
||||
// Flush TMA stores
|
||||
// NOTES: for the first store, we have to flush all previous TMA,
|
||||
// as we don't share pipeline stages between two blocks
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
|
||||
// Wait UMMA arrival
|
||||
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Load from tensor memory into registers, and write shared memory with STSM
|
||||
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
|
||||
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
||||
|
||||
// Iterate over M waves
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
// Issue every swizzled atom and pipeline STSM and TMA store
|
||||
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStores; ++ s) {
|
||||
// Wait shared memory to be released
|
||||
const uint32_t iter_idx = w * kNumStores + s;
|
||||
if (iter_idx >= kNumTMAStoreStages) {
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
}
|
||||
|
||||
// The pipeline stage
|
||||
const auto tma_stage_idx = iter_idx % kNumTMAStoreStages;
|
||||
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
|
||||
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
|
||||
|
||||
// Store into shared memory
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
|
||||
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
|
||||
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
|
||||
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleCDMode / 16);
|
||||
|
||||
// Source and destination memory address
|
||||
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
|
||||
w * BLOCK_N + // Wave offset
|
||||
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
|
||||
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
|
||||
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
|
||||
// Load from tensor memory, store into shared memory
|
||||
uint32_t values[kNumElemsPerBankGroup];
|
||||
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
|
||||
// For FP32 output, read and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
||||
} else {
|
||||
// For BF16 output, read, cast and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3],
|
||||
values[4], values[5], values[6], values[7]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr,
|
||||
cast_into_bf16_and_pack(values[0], values[1]),
|
||||
cast_into_bf16_and_pack(values[2], values[3]),
|
||||
cast_into_bf16_and_pack(values[4], values[5]),
|
||||
cast_into_bf16_and_pack(values[6], values[7]));
|
||||
}
|
||||
}
|
||||
|
||||
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
|
||||
// NOTES: only the last stage needs to do this
|
||||
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
|
||||
tcgen05_before_thread_sync();
|
||||
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Synchronize all threads and issue TMA
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
if (epilogue_thread_idx == 0) {
|
||||
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
||||
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
|
||||
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
|
||||
// TODO: do we actually need this?
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
|
||||
// Deallocate tensor memory by warp 1
|
||||
// NOTES: warp 0 is waiting TMA store
|
||||
// TODO: do we need 2 SM allocation?
|
||||
if (epilogue_warp_idx == 1)
|
||||
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
||||
}
|
||||
|
||||
// To safely deconstruct all barriers, we need a cluster sync
|
||||
// TODO: optimize it by another round of barrier waits
|
||||
if constexpr (kNumMulticast > 1)
|
||||
cute::cluster_sync();
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");
|
||||
#endif
|
||||
}
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
@@ -136,7 +136,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
// Arrive at all CTAs
|
||||
full_barriers[i]->init(1);
|
||||
full_barriers[i]->init(kNumMulticast);
|
||||
empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads / 32);
|
||||
}
|
||||
#pragma unroll
|
||||
@@ -241,6 +241,8 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE;
|
||||
if (is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
|
||||
if (not is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive(0u);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
@@ -249,6 +251,8 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
if (is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive();
|
||||
if (not is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive(0u);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,3 +1,343 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: add implement
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/scheduler.cuh>
|
||||
#include <deep_gemm/common/sm90_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm90;
|
||||
|
||||
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kSwizzleDMode,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
||||
uint32_t kNumSMs, GemmType kGemmType>
|
||||
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
||||
sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
// Types
|
||||
using WGMMA = typename BF16MMASelector<BLOCK_N>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
|
||||
|
||||
// Overwrite shape constants if the compiler gives
|
||||
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
||||
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
||||
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
||||
|
||||
// Shared memory
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
||||
const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_idx();
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
|
||||
__nv_bfloat16* smem_a[kNumStages];
|
||||
__nv_bfloat16* smem_b[kNumStages];
|
||||
|
||||
// TMA Barrier for both divisible and non-divisible cases
|
||||
Barrier* full_barriers[kNumStages];
|
||||
Barrier* empty_barriers[kNumStages];
|
||||
|
||||
// Fill shared memory pointers
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<__nv_bfloat16*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<__nv_bfloat16*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Fill barriers
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i] = barrier_start_ptr + i;
|
||||
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
|
||||
}
|
||||
|
||||
// Initialize barriers
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
auto launch_k_iterations = [=](const auto& func) {
|
||||
if constexpr (kNumLastStages == 0) {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
} else {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
func(num_iterations - 1, NotDivisibleK{});
|
||||
}
|
||||
};
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr uint32_t kNumTMARegisters = 48;
|
||||
constexpr uint32_t kNumMathRegisters = 224;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(divisible_type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
|
||||
// Assign TMA multicast number into A and B
|
||||
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
|
||||
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
|
||||
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
||||
|
||||
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
|
||||
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
|
||||
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
|
||||
auto& full_barrier = *full_barriers[s];
|
||||
uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
|
||||
|
||||
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[s], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
|
||||
num_tma_multicast_a);
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
|
||||
num_tma_multicast_b);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
||||
if constexpr (kNumTMAMulticast > 1) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s)
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
||||
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2);
|
||||
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
|
||||
float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](uint32_t s) {
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
||||
} else {
|
||||
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
|
||||
}
|
||||
};
|
||||
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(divisible_type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
|
||||
// TODO: remove some useless computation for unaligned Ms
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait TMA arrivals
|
||||
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, shifted_accum, 1);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival at the last warpgroup wave
|
||||
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
});
|
||||
|
||||
// TMA checks
|
||||
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
|
||||
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
|
||||
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
|
||||
DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type");
|
||||
DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom");
|
||||
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
|
||||
"Unaligned TMA store or too many TMA store instructions");
|
||||
DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
|
||||
|
||||
// Wait last TMA store to be finished
|
||||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
|
||||
cute::tma_store_wait<0>();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Write back to shared memory using STSM and issue TMA stores
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
// Swizzle or padding into the correct address
|
||||
uint8_t* smem_ptr = nullptr;
|
||||
if constexpr (kSwizzleDMode > 0) {
|
||||
// Calculate the swizzling atom offset and in-atom offset
|
||||
constexpr uint32_t kNumBankGroupBytes = 16;
|
||||
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
|
||||
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
|
||||
// - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
|
||||
constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleDMode / 16);
|
||||
|
||||
// Add back into the base pointer
|
||||
// NOTES: think twice before modifying this, as changes may affect the number of instructions
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
|
||||
warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
|
||||
m_offset * kSwizzleDMode + // Wave offset
|
||||
atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
} else {
|
||||
// No swizzling, just padding
|
||||
// TODO: support more cases
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
|
||||
}
|
||||
|
||||
// NOTES: only 16 lanes' addresses are used
|
||||
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
|
||||
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
|
||||
smem_ptr
|
||||
);
|
||||
}
|
||||
}
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
// TODO: compatible with FP32 output
|
||||
constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked;
|
||||
DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
|
||||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
|
||||
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
|
||||
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
|
||||
n_block_idx * BLOCK_N + in_block_n_offset,
|
||||
scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx));
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
@@ -175,7 +175,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) {
|
||||
|
||||
@@ -4,6 +4,45 @@
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K,
|
||||
uint32_t PADDED_SF_K = SF_K + (1 - (SF_K % 2))>
|
||||
__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
|
||||
typedef typename Vectorized<sizeof(float) * SF_K>::vec_t in_vec_t;
|
||||
constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float);
|
||||
constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec;
|
||||
|
||||
// Shapes and strides
|
||||
extern __shared__ float smem_buffer[];
|
||||
constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(float));
|
||||
const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
|
||||
const auto tma_aligned_mn = align<uint32_t>(mn, kNumTMAAlignedElems);
|
||||
|
||||
// Shift into the block
|
||||
sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
|
||||
out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * SF_K;
|
||||
const auto& local_sf = reinterpret_cast<const in_vec_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
|
||||
|
||||
// Load
|
||||
for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) {
|
||||
auto in_vec = __ldg(local_sf + i);
|
||||
const auto& in_values = reinterpret_cast<float*>(&in_vec);
|
||||
|
||||
const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec;
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumElemsPerVec; ++ j)
|
||||
smem_buffer[row * PADDED_SF_K + col + j] = in_values[j];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Store
|
||||
#pragma unroll
|
||||
for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) {
|
||||
const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn;
|
||||
const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx;
|
||||
out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// NOTES: the two kernels below always pack the K dimension
|
||||
|
||||
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K>
|
||||
|
||||
Reference in New Issue
Block a user