Make various updates and fixes (#198)

This commit is contained in:
Ray Wang
2025-09-25 16:19:07 +08:00
committed by GitHub
parent 79f48ee15a
commit 3f71de7aa9
45 changed files with 3281 additions and 1060 deletions

View File

@@ -25,15 +25,22 @@ from deep_gemm_cpp import (
# FP8 GEMMs
fp8_gemm_nt, fp8_gemm_nn,
fp8_gemm_tn, fp8_gemm_tt,
fp8_gemm_nt_skip_head_mid,
m_grouped_fp8_gemm_nt_contiguous,
m_grouped_fp8_gemm_nn_contiguous,
m_grouped_fp8_gemm_nt_masked,
k_grouped_fp8_gemm_nt_contiguous,
k_grouped_fp8_gemm_tn_contiguous,
# BF16 GEMMs
bf16_gemm_nt, bf16_gemm_nn,
bf16_gemm_tn, bf16_gemm_tt,
m_grouped_bf16_gemm_nt_contiguous,
m_grouped_bf16_gemm_nt_masked,
# cuBLASLt GEMMs
cublaslt_gemm_nt, cublaslt_gemm_nn,
cublaslt_gemm_tn, cublaslt_gemm_tt,
# Einsum kernels
einsum,
# Layout kernels
transform_sf_into_required_layout
)

View File

@@ -0,0 +1,27 @@
#pragma once
#include <deep_gemm/common/types.hpp>
#include <deep_gemm/common/utils.cuh>
namespace deep_gemm {
struct EpilogueIdentity {
template <uint32_t STORE_BLOCK_N>
__device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
return n_idx;
}
};
template <uint32_t kLeft, uint32_t kMid, uint32_t kRight>
struct EpilogueHeadSplits: EpilogueIdentity {
template <uint32_t STORE_BLOCK_N>
__device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0
and kRight % STORE_BLOCK_N == 0, "Invalid head splits config");
return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid;
}
};
#pragma clang diagnostic pop
} // namespace deep_gemm

View File

@@ -0,0 +1,44 @@
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda/std/cstdint>
#include <cuda/std/utility>
#include <deep_gemm/common/utils.cuh>
// Operation functors
template <typename T> struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } };
template <typename T> struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } };
template <typename T> struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } };
template <typename T> struct ReduceAnd { __device__ T operator()(T a, T b) const { return a & b; } };
template <typename T> struct ReduceOr { __device__ T operator()(T a, T b) const { return a | b; } };
// Unified reduction function
template <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
__forceinline__ __device__ T warp_reduce(T value, Op op) {
DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
"Invalid number of lanes");
constexpr uint32_t mask = 0xffffffff;
if constexpr (kIntergroupReduce) {
if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1));
if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2));
if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4));
if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8));
if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
} else {
if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8));
if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4));
if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2));
if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1));
}
return value;
}
// Convenience aliases
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_sum(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
}

View File

@@ -22,7 +22,6 @@ static constexpr uint32_t get_num_1d_blocks_per_group() {
if (usage < min_usage)
min_usage = usage, num_best_blocks = candidate;
}
return num_best_blocks;
}
@@ -33,6 +32,7 @@ template <GemmType kGemmType,
uint32_t kNumGroups,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
uint32_t SF_K_ALIGNMENT = 512u, // for k-grouped GEMM only: 128 (SM90 float SF) or 512 (SM100 UE8M0 SF)
uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<kGemmType, BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
struct Scheduler {
int current_iter = -1;
@@ -48,30 +48,40 @@ struct Scheduler {
// For grouped GEMM
int* grouped_layout;
uint32_t current_group_idx;
uint32_t current_group_idx = 0;
// Only used for masked layout
uint32_t current_m_cumsum;
uint32_t current_m_cumsum = 0;
// Only used for k-grouped layout
uint32_t current_shape_k, current_num_valid_groups, current_k_cumsum, current_sf_k_cumsum;
uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0;
uint32_t next_group_idx, next_shape_k;
// Only used for k-grouped gemm
__device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const {
for (; group_idx < kNumGroups; ++ group_idx) {
shape_k = __ldg(grouped_layout + group_idx);
if (shape_k > 0)
break;
}
}
// ReSharper disable once CppPossiblyUninitializedMember
__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n,
__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k,
int* grouped_layout = nullptr) {
num_m_blocks = ceil_div(shape_m, BLOCK_M);
num_n_blocks = ceil_div(shape_n, BLOCK_N);
current_shape_k = shape_k;
if constexpr (kGemmType == GemmType::Normal) {
num_blocks = num_m_blocks * num_n_blocks;
} else if (kGemmType == GemmType::MGroupedContiguous) {
num_blocks = num_m_blocks * num_n_blocks;
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::MGroupedMasked) {
current_group_idx = current_m_cumsum = 0;
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::KGroupedContiguous) {
current_group_idx = current_num_valid_groups = 0;
current_k_cumsum = current_sf_k_cumsum = 0;
current_shape_k = __ldg(grouped_layout + current_group_idx);
this->grouped_layout = grouped_layout;
get_next_k_group(current_group_idx, current_shape_k);
next_group_idx = current_group_idx + 1;
get_next_k_group(next_group_idx, next_shape_k);
}
}
@@ -165,17 +175,17 @@ struct Scheduler {
return false;
// Within current group
if (current_shape_k > 0 and next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks)
if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks)
break;
// Move to check the next group
if (current_shape_k > 0) {
current_k_cumsum += current_shape_k;
current_sf_k_cumsum += ceil_div(current_shape_k, 512u);
current_num_valid_groups ++;
}
if ((++ current_group_idx) != kNumGroups)
current_shape_k = __ldg(grouped_layout + current_group_idx);
current_k_cumsum += current_shape_k;
current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT);
current_num_valid_groups ++;
current_group_idx = next_group_idx ++;
current_shape_k = next_shape_k;
get_next_k_group(next_group_idx, next_shape_k);
}
get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx);
@@ -197,7 +207,7 @@ struct Scheduler {
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
if (num_blocks_in_group == 1)
return false;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked) {
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::KGroupedContiguous) {
return true;
} else {
DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type");

View File

@@ -79,12 +79,24 @@ void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_p
desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);
}
__device__ __forceinline__
static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) {
return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16;
}
// ReSharper disable once CppNotAllPathsReturnValue
template <uint32_t kSwizzleMode>
template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, bool kUseBase32, typename dtype_t>
constexpr static cute::UMMA::LayoutType to_umma_layout_type() {
DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
kSwizzleMode == 32 or kSwizzleMode == 64 or
kSwizzleMode == 128, "Invalid swizzling mode");
// A special case
if constexpr ((cute::is_same_v<dtype_t, float> and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) {
DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base");
return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B;
}
// Normal cases
if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE;
if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE;
if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B;
@@ -104,10 +116,12 @@ uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, cons
return base + (((offset + k_idx * get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
}
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, bool kUseBase32 = false, typename dtype_t>
__device__ __forceinline__
cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
const uint32_t stride_k = get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
const auto& layout_type = to_umma_layout_type<kMajorMode, kSwizzleMode, kUseBase32, dtype_t>();
const auto& num_non_contiguous = 128 / get_atom_base(layout_type);
if constexpr (kMajorMode == cute::UMMA::Major::K) {
// NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128
DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
@@ -115,9 +129,9 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id
// Atom size: 8 x `kSwizzleMode` (in bytes, on K)
// {SBO, LBO} means the byte stride between atoms on {MN, K}
// NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
const uint32_t stride_byte_offset = 8 * BLOCK_K * sizeof(dtype_t);
const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
const uint32_t leading_byte_offset = 0;
return make_smem_desc(to_umma_layout_type<kSwizzleMode>(),
return make_smem_desc(layout_type,
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
stride_byte_offset, leading_byte_offset);
} else {
@@ -132,11 +146,11 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id
// NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving
// {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
// {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
uint32_t stride_byte_offset = 8 * BLOCK_MN_ATOM * sizeof(dtype_t);
uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
if constexpr (kSwizzleMode == 16)
swap(stride_byte_offset, leading_byte_offset);
return make_smem_desc(to_umma_layout_type<kSwizzleMode>(),
return make_smem_desc(layout_type,
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
stride_byte_offset, leading_byte_offset);
}
@@ -166,4 +180,81 @@ __device__ __forceinline__ void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
// UMMA versions with relaxed assertions
struct SM100_MMA_F16BF16_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
struct SM100_MMA_F16BF16_2x1SM_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t"
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
struct SM100_MMA_MXF8F6F4_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc,
uint32_t const& tmem_sfa,
uint32_t const& tmem_sfb) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
"r"(tmem_sfa), "r"(tmem_sfb));
}
};
struct SM100_MMA_MXF8F6F4_2x1SM_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc,
uint32_t const& tmem_sfa,
uint32_t const& tmem_sfb) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
"r"(tmem_sfa), "r"(tmem_sfb));
}
};
} // namespace `deep_gemm::sm100`

View File

@@ -1,8 +1,12 @@
#pragma once
#include <cute/arch/copy_sm90_tma.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
#include <deep_gemm/common/utils.cuh>
namespace deep_gemm::sm90 {
template <int N_, typename MMA>
@@ -29,6 +33,7 @@ struct FP8MMASelector {
static constexpr auto select_mma() {
using namespace cute::SM90::GMMA;
if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN();
@@ -93,6 +98,7 @@ struct BF16MMASelector {
static constexpr auto select_mma() {
using namespace cute::SM90::GMMA;
if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS<Major::K, Major::K>();
if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS<Major::K, Major::K>();
if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS<Major::K, Major::K>();
if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS<Major::K, Major::K>();
@@ -144,6 +150,24 @@ struct SM90_U32x2_STSM_N {
}
};
struct SM90_U32x2_LDSM_N {
__device__ __forceinline__ static void
copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) {
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst_0), "=r"(dst_1)
: "l"(smem_src));
}
};
struct SM90_U32x4_LDSM_N {
__device__ __forceinline__ static void
copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3)
: "l"(smem_src));
}
};
__forceinline__ __device__ void warpgroup_arrive() {
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}
@@ -223,4 +247,37 @@ tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
}
}
__device__ __forceinline__ void
tma_3d_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& crd_2) {
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
cute::SM90_TMA_LOAD_3D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1, crd_2);
}
// Tensormap related
__device__ __forceinline__ void tensor_map_release_cta() {
asm volatile ("fence.proxy.tensormap::generic.release.cta;");
}
__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) {
auto gmem_int_desc = reinterpret_cast<uint64_t>(gmem_desc_ptr);
asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory");
}
__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) {
auto smem_int_desc = static_cast<uint32_t>(__cvta_generic_to_shared(smem_desc));
const auto new_int64_addr = reinterpret_cast<uint64_t>(new_addr);
asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr));
}
__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) {
auto smem_int_desc = __cvta_generic_to_shared(smem_desc);
asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim));
#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 5)))
asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride));
#else
DG_STATIC_ASSERT(false, "Invalid CUDA version")
#endif
}
} // namespace `deep_gemm::sm90`

View File

@@ -104,6 +104,12 @@ __device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) {
return ret;
}
__device__ __forceinline__ float2 ld_shared(const float2* ptr) {
float2 ret;
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
return ret;
}
__device__ __forceinline__ float4 ld_shared(const float4* ptr) {
float4 ret;
asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(ptr));
@@ -126,10 +132,18 @@ __device__ __forceinline__ void st_shared(const float* ptr, float val) {
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}
__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y));
}
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
}
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) {
asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(ptr), "r"(x), "r"(y));
}
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(ptr), "r"(x), "r"(y), "r"(z), "r"(w));
}

View File

@@ -17,7 +17,7 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumStages,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
@@ -84,8 +84,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == 0) {
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
@@ -93,35 +92,31 @@ sm100_bf16_gemm_impl(int* grouped_layout,
cute::prefetch_tma_descriptor(&tensor_map_c);
}
// Data on shared memory (layout as ordered below)
cd_dtype_t* smem_cd[kNumTMAStoreStages];
cutlass::bfloat16_t* smem_a[kNumStages];
cutlass::bfloat16_t* smem_b[kNumStages];
// Fill D/A/B pointers
#pragma unroll
for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i)
smem_cd[i] = reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
}
// D/A/B shared memory
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
});
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2;
// Fill the tensor memory pointer
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (threadIdx.x == 0) {
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Arrive only at the leader CTA
@@ -136,11 +131,12 @@ sm100_bf16_gemm_impl(int* grouped_layout,
// Arrive only at the leader CTA
tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads);
}
if constexpr (kTensorCoreUtilControl < 100)
tensor_core_full_barrier->init(1);
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
cutlass::arch::fence_barrier_init();
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
} else if (warp_idx == 2) {
// Allocate tensor memory
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
@@ -148,100 +144,69 @@ sm100_bf16_gemm_impl(int* grouped_layout,
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
uint32_t phase = 0;
auto launch_k_iterations = [&](const auto& func) {
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K);
const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages;
// Pipeline and TMA phases
uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0;
auto advance_pipeline = [&](uint32_t& k_block_idx) {
++ k_block_idx;
// TODO: refactor here
if (num_last_stages == 0) {
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1)
func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages);
} else {
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1)
func(k_iter, DivisibleK{}, false, num_last_stages);
func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1;
}
};
auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) {
DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2,
"Too many epilogue stages, please modify the Python heuristic as well");
accum_stage_idx == 0 ? func(0) : func(1);
// Flip phases only if reach the next first stage
stage_idx = (stage_idx + 1) % kNumStages;
phase ^= stage_idx == 0;
};
// Dispatch warps into different roles
if (warp_idx == 0) {
if (warp_idx == 0 and cute::elect_one_sync()) {
// TMA load warp
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
#pragma unroll
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
empty_barriers[stage_idx]->wait(phase ^ 1);
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait(phase ^ 1);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-majored
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-majored
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_block_idx = k_iter * kNumStages + s;
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// Add 2 CTA offsets
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Issue TMAs
if (cute::elect_one_sync()) {
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx);
}
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
if (is_leader_cta and cute::elect_one_sync())
full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
if (not is_leader_cta and cute::elect_one_sync())
full_barriers[s]->arrive(0u);
// Add 2 CTA offsets
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait(phase ^ 1);
if (is_leader_cta and cute::elect_one_sync())
full_barriers[s]->arrive();
if (not is_leader_cta and cute::elect_one_sync())
full_barriers[s]->arrive(0u);
// Issue TMAs
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx);
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
if (is_leader_cta) {
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
} else {
full_barriers[stage_idx]->arrive(0u);
}
});
}
}
} else if (warp_idx == 1 and is_leader_cta) {
// MMA issue warp
@@ -268,88 +233,89 @@ sm100_bf16_gemm_impl(int* grouped_layout,
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
// Wait tensor memory empty barrier arrival
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
// Wait tensor memory empty barrier arrival
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
tcgen05_after_thread_sync();
// UMMA and empty barrier arrival alias
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
};
// Launch MMAs
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA arrival
full_barriers[stage_idx]->wait(phase);
tcgen05_after_thread_sync();
// Empty barrier arrival
auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) {
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[s]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
};
// Launch MMAs
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
// Issue UMMA in the leader CTA
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_F16BF16_SS, SM100_MMA_F16BF16_2x1SM_SS>;
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait TMA arrival
full_barriers[s]->wait(phase);
tcgen05_after_thread_sync();
// Let tensor cores relax for lower possibility of frequency drop
DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control");
if constexpr (kTensorCoreUtilControl < 100) {
constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull;
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
const auto& start_clock = clock64();
if (cute::elect_one_sync())
while (clock64() - start_clock < kNumDummyCycles) {}
__syncwarp();
}
// Issue UMMA in the leader CTA
using cute_mma_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_MMA_F16BF16_SS <cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>,
cute::SM100_MMA_F16BF16_2x1SM_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>>;
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s);
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s);
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
cute_mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_iter > 0 or s > 0 or k > 0,
runtime_instr_desc);
}
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_block_idx > 0 or k > 0,
runtime_instr_desc);
}
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1);
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait(phase);
empty_barrier_arrive(s, false);
}
});
});
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
// Let tensor cores relax for lower possibility of frequency drop
DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control");
if constexpr (kTensorCoreUtilControl < 100) {
// For utilization control
umma_arrive(reinterpret_cast<uint64_t*>(tensor_core_full_barrier));
// Wait for last UMMA to be done
tensor_core_full_barrier->wait(tensor_core_phase);
tensor_core_phase ^= 1;
// Sleep for certain cycles
constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull;
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
const auto& start_clock = clock64();
if (cute::elect_one_sync())
while (clock64() - start_clock < kNumDummyCycles) {}
__syncwarp();
}
}
}
// To safely deconstruct barriers, we need another round of waits
const auto& iter_idx = scheduler.current_iter - 1;
if (kNumMulticast > 1 and iter_idx >= 0) {
const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
}
} else if (warp_idx >= kNumNonEpilogueThreads / 32) {
// Epilogue warp groups
const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads;
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
@@ -363,129 +329,114 @@ sm100_bf16_gemm_impl(int* grouped_layout,
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Share store pipeline between blocks
uint32_t tma_stage_idx = 0;
auto advance_store_pipeline = [&]() {
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
};
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
// Flush TMA stores
// NOTES: for the first store, we have to flush all previous TMA,
// as we don't share pipeline stages between two blocks
if (epilogue_thread_idx == 0)
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
tcgen05_after_thread_sync();
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Iterate over M waves
// Iterate over M waves
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
// Wait shared memory to be released
if (epilogue_warp_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0);
// The pipeline stage
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Store into shared memory
#pragma unroll
for (uint32_t s = 0; s < kNumStores; ++ s) {
// Wait shared memory to be released
const uint32_t iter_idx = w * kNumStores + s;
if (iter_idx >= kNumTMAStoreStages) {
if (epilogue_thread_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
}
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// The pipeline stage
const auto tma_stage_idx = iter_idx % kNumTMAStoreStages;
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// Source and destination memory address
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Source and destination memory address
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
} else {
// For BF16 output, read, cast and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr,
cast_into_bf16_and_pack(values[0], values[1]),
cast_into_bf16_and_pack(values[2], values[3]),
cast_into_bf16_and_pack(values[4], values[5]),
cast_into_bf16_and_pack(values[6], values[7]));
}
}
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
tcgen05_before_thread_sync();
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
}
__syncwarp();
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
if (epilogue_thread_idx == 0) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
cute::tma_store_arrive();
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
} else {
// For BF16 output, read, cast and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr,
cast_into_bf16_and_pack(values[0], values[1]),
cast_into_bf16_and_pack(values[2], values[3]),
cast_into_bf16_and_pack(values[4], values[5]),
cast_into_bf16_and_pack(values[6], values[7]));
}
}
}
});
}
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
if (epilogue_thread_idx == 0)
cute::tma_store_wait<0>();
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
tcgen05_before_thread_sync();
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
}
__syncwarp();
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0);
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
cute::tma_store_arrive();
}
}
}
}
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
if (epilogue_warp_idx == 1)
Allocator().free(0, kNumTmemCols);
}
// To safely deconstruct all barriers, we need a cluster sync
// TODO: optimize it by another round of barrier waits
if constexpr (kNumMulticast > 1)
cute::cluster_sync();
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");

View File

@@ -0,0 +1,265 @@
#pragma once
#include <cute/arch/cluster_sm90.hpp>
#include <cute/util/type_traits.hpp>
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm100;
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kSplitFactor,
uint32_t kSwizzleABMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages, uint32_t kNumThreads>
__global__ void __launch_bounds__(kNumThreads, 1)
sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Configs
constexpr uint32_t LAYOUT_AD_M = 128;
constexpr uint32_t kNumTMAStoreStages = 2;
// Utils
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size");
DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode");
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// Shared memory sizes
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
// Real tensor memory size and offsets
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_N>();
// Fill D/A/B
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE));
});
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2);
// Fill the tensor memory pointer
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 2 + 1);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(1);
}
tmem_full_barrier->init(1);
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
} else if (warp_idx == 2) {
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
__syncthreads();
// Block indices
const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
if (warp_idx == 0) {
// TMA load warp
for (uint32_t s = 0; s < num_total_stages; ++ s) {
const auto& stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
uint32_t m_idx = BLOCK_M * m_block_idx;
uint32_t n_idx = BLOCK_N * n_block_idx;
uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
uint32_t k_idx = sk_idx % SHAPE_K;
uint32_t s_idx = sk_idx / SHAPE_K;
// Issue TMAs
if (cute::elect_one_sync()) {
tma_copy<BLOCK_K, BLOCK_M, kSwizzleABMode, 1>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M);
tma_copy<BLOCK_K, BLOCK_N, kSwizzleABMode, 1>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N);
}
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
if (cute::elect_one_sync())
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
}
} else if (warp_idx == 1) {
// MMA issue warp
// NOTES: only the leader CTA will do this
// Make instruction descriptor
constexpr uint32_t UMMA_M = LAYOUT_AD_M;
constexpr uint32_t UMMA_N = BLOCK_N;
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t);
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
auto a_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleABMode>(smem_a[0], 0, 0);
auto b_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleABMode>(smem_b[0], 0, 0);
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
// Checks for MMA instructions
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
"Invalid MMA instruction shape");
// Wait tensor memory empty barrier arrival
tcgen05_after_thread_sync();
// Launch MMAs
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
tcgen05_after_thread_sync();
// Issue UMMA in the leader CTA
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, stage_idx);
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, stage_idx);
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
a_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_M, kSwizzleABMode, cutlass::bfloat16_t>(a_desc_base_lo, 0, k * UMMA_K);
b_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_N, kSwizzleABMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc);
}
}
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
}
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barrier));
}
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
// i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`.
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
if (warp_idx == 2)
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
// TMA checks
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float);
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(float);
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Wait UMMA arrival
tmem_full_barrier->wait(0);
tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
#pragma unroll
for (uint32_t s = 0; s < kNumStores; ++ s) {
// Wait shared memory to be released
if (s >= kNumTMAStoreStages) {
if (warp_idx == 0 and cute::elect_one_sync())
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier(kNumThreads).sync();
}
// The pipeline stage
const auto tma_stage_idx = s % kNumTMAStoreStages;
const auto m_idx = m_block_idx * BLOCK_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Source and destination memory address
uint32_t tmem_addr = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
}
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumThreads).sync();
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::SM90_TMA_REDUCE_ADD_2D::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
cute::tma_store_arrive();
}
}
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is doing TMA stores
if (warp_idx == 1)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");
#endif
}
}

View File

@@ -4,6 +4,7 @@
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/epilogue_utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
@@ -17,11 +18,12 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumStages,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t>
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
typename epilogue_type_t>
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
@@ -96,8 +98,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == 0) {
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
@@ -107,30 +108,25 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
cute::prefetch_tma_descriptor(&tensor_map_c);
}
// Data on shared memory (layout as ordered below)
cd_dtype_t* smem_cd[kNumTMAStoreStages];
cutlass::float_e4m3_t* smem_a[kNumStages];
cutlass::float_e4m3_t* smem_b[kNumStages];
uint32_t* smem_sfa[kNumStages];
uint32_t* smem_sfb[kNumStages];
// D/A/B shared memory
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
});
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
// Fill D/A/B pointers
#pragma unroll
for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i)
smem_cd[i] = reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
}
// Fill SFA/SFB
// SFA/SFB shared memory
auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
smem_sfa[i] = reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
smem_sfb[i] = reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
}
auto smem_sfa = PatternVisitor([=](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
});
auto smem_sfb = PatternVisitor([=](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer +
@@ -148,7 +144,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (threadIdx.x == 0) {
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Arrive at all CTAs
@@ -166,9 +162,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
cutlass::arch::fence_barrier_init();
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
} else if (warp_idx == 2) {
// Allocate tensor memory
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
@@ -176,108 +171,75 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
uint32_t phase = 0;
auto launch_k_iterations = [&](const auto& func) {
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K);
const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages;
// Pipeline and TMA phases
uint32_t stage_idx = 0, phase = 0;
auto advance_pipeline = [&](uint32_t& k_block_idx) {
++ k_block_idx;
// TODO: refactor here
if (num_last_stages == 0) {
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1)
func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages);
} else {
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1)
func(k_iter, DivisibleK{}, false, num_last_stages);
func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1;
}
};
auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) {
DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2,
"Too many epilogue stages, please modify the Python heuristic as well");
accum_stage_idx == 0 ? func(0) : func(1);
// Flip phases only if reach the next first stage
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
phase ^= stage_idx == 0;
};
// Dispatch warps into different roles
if (warp_idx == 0) {
if (warp_idx == 0 and cute::elect_one_sync()) {
// TMA load warp
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
empty_barriers[stage_idx]->wait(phase ^ 1);
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait(phase ^ 1);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-majored
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-majored
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_block_idx = k_iter * kNumStages + s;
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// Add 2 CTA offsets
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Issue TMAs
if (cute::elect_one_sync()) {
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx);
}
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
// Issue SFA and SFB TMAs at certain stages
// No swizzling, so one TMA for one SF is enough
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
tma_copy<BLOCK_M, 1, 0, 1>(&tensor_map_sfa, full_barriers[s], smem_sfa[s], m_block_idx * BLOCK_M,
scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad)));
tma_copy<BLOCK_N, 1, 0, 1>(&tensor_map_sfb, full_barriers[s], smem_sfb[s], n_block_idx * BLOCK_N,
scheduler.template get_global_idx<true, KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx));
num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t);
}
// Arrive at full barriers
if (cute::elect_one_sync())
full_barriers[s]->arrive_and_expect_tx(num_arrival_bytes);
// Add 2 CTA offsets
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait(phase ^ 1);
if (cute::elect_one_sync())
full_barriers[s]->arrive();
// Issue TMAs
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx);
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
// Issue SFA and SFB TMAs at certain stages
// No swizzling, so one TMA for one SF is enough
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0) {
tma_copy<BLOCK_M, 1, 0, 1>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M,
scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad)));
tma_copy<BLOCK_N, 1, 0, 1>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N,
scheduler.template get_global_idx<true, KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx));
num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t);
}
});
// Arrive at full barriers
full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes);
}
}
} else if (warp_idx == 1 and is_leader_cta) {
// MMA issue warp
@@ -307,101 +269,93 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
// Wait tensor memory empty barrier arrival
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
// Wait tensor memory empty barrier arrival
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
tcgen05_after_thread_sync();
// Empty barrier arrival
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
};
// Launch MMAs
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA and SF-transpose arrival
with_sf_full_barriers[stage_idx]->wait(phase);
tcgen05_after_thread_sync();
// Empty barrier arrival
auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) {
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[s]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
};
// Launch MMAs
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
// Do SF copy at certain stages
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
// SFA and SFB copy
// TODO: process shared memory descriptor by addition
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait TMA and SF-transpose arrival
with_sf_full_barriers[s]->wait(phase);
tcgen05_after_thread_sync();
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
}
}
__syncwarp();
// Do SF copy at certain stages
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
// SFA and SFB copy
// TODO: process shared memory descriptor by addition
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[s] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[s] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
}
}
__syncwarp();
// Issue UMMA in the leader CTA
using cute_mma_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_MMA_MXF8F6F4_SS <cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
cutlass::float_ue8m0_t, UMMA_M, UMMA_N, kMajorA, kMajorB>,
cute::SM100_MMA_MXF8F6F4_2x1SM_SS<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
cutlass::float_ue8m0_t, UMMA_M, UMMA_N, kMajorA, kMajorB>>;
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s);
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s);
// Issue UMMA in the leader CTA
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_MXF8F6F4_SS, SM100_MMA_MXF8F6F4_2x1SM_SS>;
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t>(b_desc_base_lo, 0, k * UMMA_K);
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t>(b_desc_base_lo, 0, k * UMMA_K);
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
cute_mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_iter > 0 or s > 0 or k > 0,
runtime_instr_desc,
kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32),
kTmemStartColOfSFB);
}
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_block_idx > 0 or k > 0,
runtime_instr_desc,
kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32),
kTmemStartColOfSFB);
}
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1);
}
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
with_sf_full_barriers[s]->wait(phase);
empty_barrier_arrive(s, false);
}
});
});
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
}
}
// To safely deconstruct barriers, we need another round of waits
const auto& iter_idx = scheduler.current_iter - 1;
if (kNumMulticast > 1 and iter_idx >= 0) {
const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
}
} else if (warp_idx == 2) {
// UTCCP transposer
@@ -418,43 +372,30 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
};
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA arrival
full_barriers[stage_idx]->wait(phase);
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait TMA arrival
full_barriers[s]->wait(phase);
// Transpose for UTCCP at certain stages
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfa[s] + i * kNumUTCCPAlignedElems);
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfb[s] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
// Arrive
with_sf_full_barriers[s]->arrive(0u);
// Transpose for UTCCP at certain stages
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait(phase);
with_sf_full_barriers[s]->arrive(0u);
}
});
// Arrive
with_sf_full_barriers[stage_idx]->arrive(0u);
}
}
} else if (warp_idx >= kNumNonEpilogueThreads / 32) {
// Epilogue warp groups
const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads;
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
@@ -468,129 +409,113 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Share store pipeline between blocks
uint32_t tma_stage_idx = 0;
auto advance_store_pipeline = [&]() {
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
};
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
// Flush TMA stores
// NOTES: for the first store, we have to flush all previous TMA,
// as we don't share pipeline stages between two blocks
if (epilogue_thread_idx == 0)
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
tcgen05_after_thread_sync();
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Iterate over M waves
// Iterate over M waves
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
// Wait shared memory to be released
if (epilogue_warp_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0);
// The pipeline stage
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(n_block_idx * BLOCK_N + s * STORE_BLOCK_N);
// Store into shared memory
#pragma unroll
for (uint32_t s = 0; s < kNumStores; ++ s) {
// Wait shared memory to be released
const uint32_t iter_idx = w * kNumStores + s;
if (iter_idx >= kNumTMAStoreStages) {
if (epilogue_thread_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
}
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// The pipeline stage
const auto tma_stage_idx = iter_idx % kNumTMAStoreStages;
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// Source and destination memory address
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Source and destination memory address
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
} else {
// For BF16 output, read, cast and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr,
cast_into_bf16_and_pack(values[0], values[1]),
cast_into_bf16_and_pack(values[2], values[3]),
cast_into_bf16_and_pack(values[4], values[5]),
cast_into_bf16_and_pack(values[6], values[7]));
}
}
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
tcgen05_before_thread_sync();
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
}
__syncwarp();
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
if (epilogue_thread_idx == 0) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
cute::tma_store_arrive();
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
} else {
// For BF16 output, read, cast and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr,
cast_into_bf16_and_pack(values[0], values[1]),
cast_into_bf16_and_pack(values[2], values[3]),
cast_into_bf16_and_pack(values[4], values[5]),
cast_into_bf16_and_pack(values[6], values[7]));
}
}
}
});
}
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
if (epilogue_thread_idx == 0)
cute::tma_store_wait<0>();
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
tcgen05_before_thread_sync();
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
}
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumEpilogueThreads, 0);
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
cute::tma_store_arrive();
}
}
}
}
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
if (epilogue_warp_idx == 1)
Allocator().free(0, kNumTmemCols);
}
// To safely deconstruct all barriers, we need a cluster sync
// TODO: optimize it by another round of barrier waits
if constexpr (kNumMulticast > 1)
cute::cluster_sync();
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");

View File

@@ -5,6 +5,7 @@
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <deep_gemm/common/epilogue_utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
@@ -22,7 +23,8 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
GemmType kGemmType, typename cd_dtype_t>
GemmType kGemmType, typename cd_dtype_t,
typename epilogue_type_t>
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
@@ -88,8 +90,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == 0) {
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
@@ -133,7 +134,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (threadIdx.x == 0) {
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Arrive at all CTAs
@@ -149,9 +150,8 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
cutlass::arch::fence_barrier_init();
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
} else if (warp_idx == 2) {
// Allocate tensor memory
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
@@ -174,7 +174,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
// Register configurations
constexpr uint32_t kNumNonEpilogueRegisters = 64;
@@ -435,7 +435,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// as we don't share pipeline stages between two blocks
if (epilogue_thread_idx_in_warpgroup == 0)
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx);
// Write shared memory
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
@@ -449,13 +449,13 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
if (s >= kNumTMAStoreStages) {
if (epilogue_thread_idx_in_warpgroup == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx);
}
// The pipeline stage
const auto tma_stage_idx = s % kNumTMAStoreStages;
const auto m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_m, BLOCK_M, m_block_idx);
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(n_block_idx * BLOCK_N + s * STORE_BLOCK_N);
const auto local_smem_cd = smem_cd[tma_stage_idx] + epilogue_warpgroup_idx * STORE_BLOCK_M * STORE_BLOCK_N;
// Store into shared memory
@@ -502,7 +502,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
cutlass::arch::NamedBarrier::sync(STORE_BLOCK_M, epilogue_warpgroup_idx);
if (epilogue_thread_idx_in_warpgroup == 0) {
cute::SM90_TMA_STORE_2D::copy(
&tensor_map_d, local_smem_cd,
@@ -512,10 +512,6 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
}
}
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
if (epilogue_thread_idx_in_warpgroup == 0)
cute::tma_store_wait<0>();
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
if (epilogue_warp_idx == 1)

View File

@@ -25,7 +25,8 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs, GemmType kGemmType>
uint32_t kNumSMs, GemmType kGemmType,
typename cd_dtype_t>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_bf16_gemm_impl(int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
@@ -44,7 +45,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Shared memory
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(cd_dtype_t);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
@@ -55,7 +56,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
const uint32_t lane_idx = get_lane_idx();
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == kNumMathThreads) {
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
@@ -67,7 +68,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Data on shared memory
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
auto smem_d = reinterpret_cast<cd_dtype_t*>(smem_buffer);
__nv_bfloat16* smem_a[kNumStages];
__nv_bfloat16* smem_b[kNumStages];
@@ -91,7 +92,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
}
// Initialize barriers
if (threadIdx.x == kNumMathThreads) {
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
@@ -99,7 +100,6 @@ sm90_bf16_gemm_impl(int* grouped_layout,
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
cutlass::arch::fence_barrier_init();
}
@@ -125,14 +125,14 @@ sm90_bf16_gemm_impl(int* grouped_layout,
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
if (threadIdx.x >= kNumMathThreads) {
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) {
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto divisible_type) {
@@ -203,7 +203,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
}
};
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Launch MMAs
launch_k_iterations([&](uint32_t k_iter, auto divisible_type) {
@@ -237,11 +237,10 @@ sm90_bf16_gemm_impl(int* grouped_layout,
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive(s);
}
// Notify barrier arrival
empty_barrier_arrive(s);
}
// Wait unaligned cases
@@ -256,7 +255,6 @@ sm90_bf16_gemm_impl(int* grouped_layout,
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type");
DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom");
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
"Unaligned TMA store or too many TMA store instructions");
@@ -265,60 +263,76 @@ sm90_bf16_gemm_impl(int* grouped_layout,
// Wait last TMA store to be finished
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Write back to shared memory using STSM and issue TMA stores
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
if constexpr (std::is_same_v<cd_dtype_t, cutlass::bfloat16_t>) {
// Write back to shared memory using STSM and issue TMA stores
DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type");
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// Swizzle or padding into the correct address
uint8_t* smem_ptr = nullptr;
if constexpr (kSwizzleDMode > 0) {
// Calculate the swizzling atom offset and in-atom offset
constexpr uint32_t kNumBankGroupBytes = 16;
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// Swizzle or padding into the correct address
uint8_t* smem_ptr = nullptr;
if constexpr (kSwizzleDMode > 0) {
// Calculate the swizzling atom offset and in-atom offset
constexpr uint32_t kNumBankGroupBytes = 16;
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
// - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
col ^= row % (kSwizzleDMode / 16);
// Reshape the atom in another view and swizzle
// - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
// - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
col ^= row % (kSwizzleDMode / 16);
// Add back into the base pointer
// NOTES: think twice before modifying this, as changes may affect the number of instructions
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
m_offset * kSwizzleDMode + // Wave offset
atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
} else {
// No swizzling, just padding
// TODO: support more cases
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
// Add back into the base pointer
// NOTES: think twice before modifying this, as changes may affect the number of instructions
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
m_offset * kSwizzleDMode + // Wave offset
atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
} else {
// No swizzling
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
}
// NOTES: only 16 lanes' addresses are used
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
smem_ptr
);
}
}
}
else {
// Use `st.shared` if STSM is not available
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2);
auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2);
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
}
// NOTES: only 16 lanes' addresses are used
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
smem_ptr
);
}
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Use TMA store to write back to global memory
// TODO: compatible with FP32 output
constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked;
DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {

View File

@@ -0,0 +1,173 @@
#pragma once
#include <cute/arch/cluster_sm90.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kSplitFactor,
uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
float *d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Types
using WGMMA = typename BF16MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
// Shared memory
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
// Configs
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_idx();
DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M");
DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads");
DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads");
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
// Fill shared memory pointers
extern __shared__ __align__(1024) uint8_t smem_buffer[];
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumMathThreads);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
}
// Synchronize all threads to make barrier visible in normal memory model
__syncthreads();
// Register reconfigurations
constexpr uint32_t kNumTMARegisters = 40;
constexpr uint32_t kNumMathRegisters = 232;
// Block indices
const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
// Persistently schedule over blocks
#pragma unroll
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait consumer release
const auto& stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1);
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
const uint32_t& k_idx = sk_idx % SHAPE_K;
const uint32_t& s_idx = sk_idx / SHAPE_K;
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1);
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
float accum[WGMMA::kNumAccum] = {0};
// Launch MMAs
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrivals
const auto& stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, 1);
}
warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
empty_barriers[stage_idx]->arrive();
}
const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
if (col + i * 8 >= SHAPE_N)
break;
if (row < SHAPE_M) {
atomicAdd(reinterpret_cast<float2*>(d + (row + 0) * SHAPE_N + col + i * 8),
make_float2(accum[i * 4 + 0], accum[i * 4 + 1]));
}
if (row + 8 < SHAPE_M) {
atomicAdd(reinterpret_cast<float2*>(d + (row + 8) * SHAPE_N + col + i * 8),
make_float2(accum[i * 4 + 2], accum[i * 4 + 3]));
}
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
}; // namespace deep_gemm

View File

@@ -1,3 +1,348 @@
#pragma once
// TODO: add implement
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumGroups,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs,
GemmType kGemmType, typename cd_dtype_t>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
int* grouped_layout,
cute::TmaDescriptor* tensor_map_buffer,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a_base,
const __grid_constant__ cute::TmaDescriptor tensor_map_b_base,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads");
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
// Overwrite shape constants if the compiler gives
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Shared memory
static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment");
// Configs
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = threadIdx.x % 32;
// Prefetch TMA descriptors at the very beginning
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a_base);
cute::prefetch_tma_descriptor(&tensor_map_b_base);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
cute::prefetch_tma_descriptor(&tensor_map_sfb);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Tensor maps on shared and global memory
auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * i);
});
auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * (2 + i));
});
auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; });
auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; });
// Data on shared memory
auto smem_d = reinterpret_cast<float*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE));
});
auto smem_sfb = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE));
});
// Barriers on shared memory
constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE);
auto full_barriers = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast<uint32_t>(sizeof(Barrier))));
});
auto empty_barriers = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast<uint32_t>(sizeof(Barrier))));
});
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
// Load tensormap A/B to shared memory
if constexpr (kGemmType == GemmType::KGroupedContiguous) {
*smem_tensor_map_a[0] = tensor_map_a_base;
*smem_tensor_map_a[1] = tensor_map_a_base;
*smem_tensor_map_b[0] = tensor_map_b_base;
*smem_tensor_map_b[1] = tensor_map_b_base;
}
// Initialize barriers
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
// even with TMA multicast disabled, we want to make the behavior aligned
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
}
// Synchronize all threads to make barrier visible in normal memory model
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// Pipeline unroll control
constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages);
// Register reconfigurations (more math registers are needed with unrolling)
constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24);
constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240);
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs, 128u>(shape_m, shape_n, shape_k, grouped_layout);
// TMA and MMA pipeline
const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple<uint32_t, uint32_t> {
return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase
};
uint32_t iter_idx = 0;
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base;
const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base;
uint32_t last_group_idx = kNumGroups, sum_k = 0;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Assign TMA multicast number into A and B
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
const uint32_t& m_idx = m_block_idx * BLOCK_M;
const uint32_t& n_idx = n_block_idx * BLOCK_N;
if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) {
const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1;
const uint32_t& next_stage_idx = stage_idx ^ 1;
last_group_idx = scheduler.current_group_idx;
// Prepare next tensor map
sum_k += scheduler.current_shape_k;
if (scheduler.next_group_idx < kNumGroups) {
tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + sum_k * shape_m);
tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + sum_k * shape_n);
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
*(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]);
*(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]);
tensor_map_release_cta();
}
// Get current tensor map
if (scheduler.current_num_valid_groups > 0) {
tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]);
tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]);
current_tensor_map_a = gmem_tensor_map_a[stage_idx];
current_tensor_map_b = gmem_tensor_map_b[stage_idx];
}
}
#pragma unroll kNumPipelineUnrolls
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) {
// Wait consumer release
CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
empty_barriers[stage_idx]->wait(phase ^ 1);
// Issue TMA
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t& k_idx = k_block_idx * BLOCK_K;
const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
tma_copy(&tensor_map_sfa, reinterpret_cast<uint64_t*>(&full_barrier), smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
tma_copy(&tensor_map_sfb, reinterpret_cast<uint64_t*>(&full_barrier), smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
tma_copy(current_tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier), smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
tma_copy(current_tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier), smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE);
}
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s) {
CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
empty_barriers[stage_idx]->wait(phase ^ 1);
}
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4;
const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Accumulation for WGMMA or CUDA promotion
DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes");
const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K);
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
float2 scales_b[WGMMA::kNumAccum / 4];
// Empty barrier arrival
auto empty_barrier_arrive = [&](uint32_t s) {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
}
};
#pragma unroll kNumPipelineUnrolls
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) {
// Wait TMA arrivals
CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
full_barriers[stage_idx]->wait(phase);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0);
auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1);
// Read B scales
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(stage_idx);
// Promote with scales
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
const float &scale_b_0 = scales_b[i].x;
const float &scale_b_1 = scales_b[i].y;
final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
}
}
// Flush previous stores
if (warp_idx % 4 == 0 and cute::elect_one_sync())
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
// Store to D shared memory
const auto& smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
const auto& smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
// Use TMA store to write back to global memory
if (warp_idx % 4 == 0 and cute::elect_one_sync()) {
cute::SM90_TMA_REDUCE_ADD_2D::copy(
&tensor_map_d, smem_d_0, n_block_idx * BLOCK_N,
current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0);
cute::tma_store_arrive();
}
__syncwarp();
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -10,6 +10,7 @@
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include <deep_gemm/common/epilogue_utils.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
@@ -18,15 +19,15 @@ namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd>
__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) {
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd, typename func_t>
__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
if (num_former_iters == kNumFormerIters) {
inner_launch_k_iterations(func, cute::Int<kNumFormerIters>{});
func(cute::Int<kNumFormerIters>{});
return;
}
if constexpr (kNumFormerIters + kGap <= kEnd)
outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
dispatch_num_former_iters<kNumFormerIters + kGap, kGap, kEnd>(num_former_iters, func);
}
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
@@ -36,7 +37,8 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs, GemmType kGemmType>
uint32_t kNumSMs, GemmType kGemmType,
typename epilogue_type_t>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
@@ -69,14 +71,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
const uint32_t& smem_sfb_size = align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
// Configs
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K);
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_idx();
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == kNumMathThreads) {
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
@@ -90,35 +90,26 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Data on shared memory
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
__nv_fp8_e4m3* smem_a[kNumStages];
__nv_fp8_e4m3* smem_b[kNumStages];
float* smem_sfa[kNumStages];
float* smem_sfb;
// TMA Barrier for both divisible and non-divisible cases
Barrier* full_barriers[kNumStages];
Barrier* empty_barriers[kNumStages];
// Fill shared memory pointers
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
smem_sfa[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SFA_SIZE_PER_STAGE);
}
smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE));
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE);
});
auto smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE);
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i] = barrier_start_ptr + i;
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
}
auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
// Initialize barriers
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
if (threadIdx.x == kNumMathThreads) {
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
// even with TMA multicast disabled, we want to make the behavior aligned
#pragma unroll
@@ -128,107 +119,72 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
cutlass::arch::fence_barrier_init();
}
// Synchronize all threads to make barrier visible in normal memory model
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
struct SkipComputation {};
struct NotSkipComputation {};
auto launch_k_iterations = [=](const auto& func, bool skip_computation, uint32_t num_former_iters) {
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
// NOTES: for too-many branches (> 5), we disable this optimization
// Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) {
if (skip_computation) {
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type);
} else if (shape_k % kFullKOfAllStages == 0) {
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
} else {
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
func(num_iterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type);
}
}, func, kShouldOptimize ? num_former_iters : 0);
};
// Register reconfigurations
constexpr uint32_t kNumTMARegisters = 40;
constexpr uint32_t kNumMathRegisters = 232;
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
if (threadIdx.x >= kNumMathThreads) {
// Pipeline and TMA phases
uint32_t stage_idx = 0, phase = 0;
auto advance_pipeline = [&](uint32_t& k_block_idx) {
++ k_block_idx;
// Flip phases only if reach the next first stage
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
phase ^= stage_idx == 0;
};
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) {
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) {
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(divisible_type), DivisibleK>;
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
// Assign TMA multicast number into A and B
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
// Assign TMA multicast number into A and B
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
empty_barriers[stage_idx]->wait(phase ^ 1);
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
// Issue TMA A
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t k_idx = k_block_idx * BLOCK_K;
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
num_tma_multicast_a);
tma_copy(&tensor_map_sfa, reinterpret_cast<uint64_t*>(&full_barrier),
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.get_global_idx<kWithGroupOffsetA>(shape_k_scales, 1, k_block_idx),
num_tma_multicast_a);
// Issue TMA A
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
auto& full_barrier = *full_barriers[s];
uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
num_tma_multicast_a);
tma_copy(&tensor_map_sfa, reinterpret_cast<uint64_t*>(&full_barrier),
smem_sfa[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx<kWithGroupOffsetA>(shape_k_scales, 1, k_idx / BLOCK_K),
num_tma_multicast_a);
// Issue TMA B
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
num_tma_multicast_b);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
}, false, 0);
// Issue TMA B
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
num_tma_multicast_b);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
}
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s)
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
for (uint32_t i = 0; i < kNumStages; advance_pipeline(i))
empty_barriers[stage_idx]->wait(phase ^ 1);
}
}
} else {
@@ -239,6 +195,11 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
auto b_desc = make_smem_desc(smem_b[0], 1);
const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Decide the number of scales B to load
@@ -259,7 +220,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
st_shared(smem_sfb + i, __ldg(local_sfb + i));
}
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Accumulation for WGMMA or CUDA promotion
constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2);
@@ -267,90 +228,96 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
// Empty barrier arrival
auto empty_barrier_arrive = [&](uint32_t s) {
auto empty_barrier_arrive = [&]() {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void();
} else {
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void();
}
};
// Launch MMAs
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) {
constexpr bool kSkipComputation = cute::is_same_v<decltype(skip_type), SkipComputation>;
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(divisible_type), DivisibleK>;
constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? kNumStages : kNumLastStages);
// Skip useless computations
if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) {
// The compiler must know the dynamic variable `num_former_iters`'s real value
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_sfb + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_sfb + k_iter * kNumStages + s + shape_k_scales);
// Dispatch `num_former_iters` and launch MMAs
dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) {
#pragma unroll 8
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
// Read B scales
float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales);
// TODO: remove some useless computation for unaligned Ms
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
// Wait TMA arrivals
full_barriers[stage_idx]->wait(phase);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_sfa[s] + r_0 + m_offset);
auto scale_a_1 = ld_shared(smem_sfa[s] + r_1 + m_offset);
// TODO: remove some useless computation for unaligned Ms
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0 + m_offset);
auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1 + m_offset);
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive(s);
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16;
b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16;
WGMMA::wgmma(a_desc, b_desc, accum, k);
}
warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive();
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
}
}
}
// Wait unaligned cases
});
} else {
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
empty_barrier_arrive(s);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
full_barriers[stage_idx]->wait(phase);
empty_barrier_arrive();
}
}, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters);
}
// TMA checks
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
@@ -364,7 +331,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Wait last TMA store to be finished
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Write back to shared memory using STSM and issue TMA stores
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
@@ -413,7 +380,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
}
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Use TMA store to write back to global memory
// TODO: compatible with FP32 output
@@ -423,7 +390,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
n_block_idx * BLOCK_N + in_block_n_offset,
epilogue_type_t::apply_index_n<TMA_D_BLOCK_N>(n_block_idx * BLOCK_N + in_block_n_offset),
scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
}

View File

@@ -91,7 +91,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
with suppress():
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) if not using_nsys else None
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
with profiler:
for i in range(2):
@@ -112,10 +112,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
is_tuple = isinstance(kernel_names, tuple)
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
if not with_multiple_kernels:
for name in kernel_names:
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'
# Save chrome traces
if trace_path is not None:
@@ -136,6 +135,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
total_num += int(num_str)
break
kernel_times.append(total_time / total_num)
kernel_times.append(total_time / total_num if total_num > 0 else 0)
return tuple(kernel_times) if is_tuple else kernel_times[0]

View File

@@ -16,13 +16,16 @@ def ceil_to_ue8m0(x: torch.Tensor):
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
assert x.dim() == 2
m, n = x.shape
x_view = x.view(m, -1, 128)
padded_n = align(n, 128)
x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0)
x_padded[:, :n] = x
x_view = x_padded.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
sf = x_amax / 448.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -54,4 +57,4 @@ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -
sf = x_amax / 448.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled, sf.squeeze()
return x_scaled, sf.squeeze()