Multiple updates and refactorings (#280)

This commit is contained in:
Zhean Xu
2026-01-16 17:06:52 +08:00
committed by GitHub
parent 3ccf40c53a
commit 0f5f266202
55 changed files with 2706 additions and 891 deletions

View File

@@ -51,6 +51,8 @@ struct Scheduler {
uint32_t current_group_idx = 0;
// Only used for masked layout
uint32_t current_m_cumsum = 0;
// Only used for countiguous psum layout
uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0;
// Only used for k-grouped layout
uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0;
uint32_t next_group_idx, next_shape_k;
@@ -72,12 +74,16 @@ struct Scheduler {
current_shape_k = shape_k;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
num_blocks = num_m_blocks * num_n_blocks;
} else if (kGemmType == GemmType::MGroupedContiguous) {
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
num_blocks = num_m_blocks * num_n_blocks;
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::MGroupedMasked) {
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::KGroupedContiguous) {
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
this->grouped_layout = grouped_layout;
current_psum_m = __ldg(grouped_layout);
num_m_blocks = ceil_div(current_psum_m, BLOCK_M);
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
this->grouped_layout = grouped_layout;
get_next_k_group(current_group_idx, current_shape_k);
next_group_idx = current_group_idx + 1;
@@ -131,7 +137,7 @@ struct Scheduler {
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0;
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
} else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
const auto offset = kWithGroupOffset ? current_group_idx : 0;
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
@@ -172,6 +178,28 @@ struct Scheduler {
}
get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx);
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
while (true) {
// Within current group
if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks)
break;
// Move to check the next group
if (++ current_group_idx == kNumGroups)
return false;
// NOTES: `num_m_blocks` varies with the increase of the group index
last_psum_m = align(current_psum_m, 128u);
current_psum_m = __ldg(grouped_layout + current_group_idx);
current_m_block_cumsum += num_m_blocks;
num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M);
}
get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx);
// NOTES: `last_psum_m` is aligned with 128
m_block_idx += last_psum_m / BLOCK_M;
DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M");
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
while (true) {
// End of the task
@@ -248,6 +276,9 @@ struct Scheduler {
return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx);
} else {
// Unreachable
DG_TRAP_ONLY_DEVICE_ASSERT(false);
}
}
};

View File

@@ -97,7 +97,8 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id
const auto& layout_type = to_umma_layout_type<kMajorMode, kSwizzleMode, kUseBase32, dtype_t>();
const auto& num_non_contiguous = 128 / get_atom_base(layout_type);
if constexpr (kMajorMode == cute::UMMA::Major::K) {
// NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128
// NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)`
// also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis
DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
// Atom size: 8 x `kSwizzleMode` (in bytes, on K)
@@ -131,8 +132,8 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id
}
__device__ __forceinline__
uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sf_id) {
desc.a_sf_id_ = sf_id, desc.b_sf_id_ = sf_id;
uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) {
desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id;
return static_cast<uint64_t>(static_cast<uint32_t>(desc)) << 32;
}
@@ -154,6 +155,20 @@ __device__ __forceinline__ void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
__device__ __forceinline__
void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier);
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbarrier_addr), "l"(cache_hint)
: "memory"
);
}
// UMMA versions with relaxed assertions
struct SM100_MMA_F16BF16_SS {
__device__ static void
@@ -231,4 +246,21 @@ struct SM100_MMA_MXF8F6F4_2x1SM_SS {
}
};
struct SM100_MMA_F16BF16_WS_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
} // namespace `deep_gemm::sm100`

View File

@@ -152,6 +152,51 @@ struct BF16MMASelector {
using type = decltype(select_type());
};
template <int N_, typename MMA>
struct TF32MMARS {
template <size_t ...Idx>
__forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
using namespace cute::SM90::GMMA;
MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
}
__forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) {
call_fma_impl(reinterpret_cast<uint32_t*>(a), desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
}
static constexpr int M = 64;
static constexpr int N = N_;
static constexpr int K = 8;
static constexpr int kNumAccum = M * N / 128;
};
template <int N, bool kUseRS = true>
struct TF32MMASelector {
static constexpr auto select_mma() {
using namespace cute::SM90::GMMA;
if constexpr (kUseRS) {
if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN();
if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN();
if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN();
if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN();
if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN();
if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN();
DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N");
}
}
static constexpr auto select_type() {
if constexpr (kUseRS) {
return TF32MMARS<N, decltype(select_mma())>();
} else {
DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now");
}
}
using type = decltype(select_type());
};
template <typename dtype_t>
struct SM90_U32x2_STSM_N {

View File

@@ -2,14 +2,36 @@
namespace deep_gemm {
enum class GemmType {
Normal = 0,
MGroupedContiguous = 1,
MGroupedMasked = 2,
KGroupedContiguous = 3,
Batched = 4
enum class MmaKind {
BF16 = 0,
MXFP8FP4 = 1,
};
constexpr __host__ __device__ int get_element_size(const MmaKind& mma_kind) {
switch (mma_kind) {
case MmaKind::BF16: return 2;
case MmaKind::MXFP8FP4: return 1;
default: return 0;
}
}
enum class GemmType {
Normal = 0,
MGroupedContiguous = 1,
MGroupedMasked = 2,
KGroupedContiguous = 3,
Batched = 4,
MGroupedContiguousWithPsumLayout = 5,
};
constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) {
switch (gemm_type) {
case GemmType::MGroupedContiguous: return true;
case GemmType::MGroupedContiguousWithPsumLayout: return true;
default: return false;
}
}
enum class KernelType {
Kernel1D1D = 0,
Kernel1D2D = 1,

View File

@@ -148,6 +148,10 @@ __device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t
asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w));
}
__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) {
asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val));
}
template <typename old_t>
__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) {
auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)});

View File

@@ -388,7 +388,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
// The pipeline stage
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Store into shared memory

View File

@@ -14,6 +14,7 @@ namespace deep_gemm {
using namespace deep_gemm::sm100;
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t kGranKA, uint32_t kGranKB,
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups,
@@ -22,7 +23,8 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
GemmType kGemmType, bool kWithAccumulation,
typename a_dtype_t, typename b_dtype_t, typename cd_dtype_t,
typename epilogue_type_t>
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
@@ -45,16 +47,21 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
constexpr uint32_t WAVE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M;
constexpr uint32_t kNumTMAStoreStages = 2;
constexpr uint32_t kNumSFStagesPerLoad = sizeof(uint32_t) / sizeof(cutlass::float_ue8m0_t);
constexpr uint32_t kNumUTCCPAlignedElems = 128;
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4;
constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4;
DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A");
DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B");
// Overwrite shape constants if the compiler gives
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
const uint32_t shape_sf_k = ceil_div(shape_k, BLOCK_K * kNumSFStagesPerLoad);
const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4);
const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4);
// Utils
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
@@ -78,8 +85,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// Share memory sizes
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems);
constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems);
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
@@ -89,7 +96,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
// NOTES: Make sure we have enough shared memory for UMMA padding
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t);
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
// Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
@@ -118,10 +125,10 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
});
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
return reinterpret_cast<a_dtype_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
return reinterpret_cast<b_dtype_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
// SFA/SFB shared memory
@@ -225,28 +232,31 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t, kIsBatchedMM>(
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::float_e4m3_t, kIsBatchedMM>(
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t, kIsBatchedMM>(
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::float_e4m3_t, kIsBatchedMM>(
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx);
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v<a_dtype_t, cutlass::float_e4m3_t> ? 1 : 2) +
SMEM_B_SIZE_PER_STAGE / (std::is_same_v<b_dtype_t, cutlass::float_e4m3_t> ? 1 : 2);
// Issue SFA and SFB TMAs at certain stages
// No swizzling, so one TMA for one SF is enough
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0) {
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
tma_copy<BLOCK_M, 1, 0>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M,
scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad)));
scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad)));
num_arrival_bytes += BLOCK_M * sizeof(uint32_t);
}
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
tma_copy<BLOCK_N, 1, 0>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N,
scheduler.template get_global_idx<true, IndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx));
num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t);
scheduler.template get_global_idx<true, IndexType::SF_K>(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx));
num_arrival_bytes += BLOCK_N * sizeof(uint32_t);
}
// Arrive at full barriers
@@ -260,9 +270,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// TODO: refactor `UMMA_M` calculation
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e4m3_t, cutlass::float_e4m3_t,
float, cutlass::float_ue8m0_t,
constexpr uint32_t UMMA_K = 32;
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<a_dtype_t, b_dtype_t, float, cutlass::float_ue8m0_t,
UMMA_M, UMMA_N, kMajorA, kMajorB>();
auto sf_desc = make_sf_desc(nullptr);
@@ -313,19 +322,20 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// Do SF copy at certain stages
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
// SFA and SFB copy
// TODO: process shared memory descriptor by addition
// TODO: process shared memory descriptor by addition
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad;
if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
}
const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad;
if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
@@ -337,17 +347,20 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// Issue UMMA in the leader CTA
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_MXF8F6F4_SS, SM100_MMA_MXF8F6F4_2x1SM_SS>;
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t>(b_desc_base_lo, 0, k * UMMA_K);
const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx);
const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx);
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id);
b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset");
a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t>(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K);
a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K);
mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_block_idx > 0 or k > 0,
@@ -391,11 +404,14 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
full_barriers[stage_idx]->wait(phase);
// Transpose for UTCCP at certain stages
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0) {
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
@@ -454,7 +470,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
// The pipeline stage
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(n_block_idx * BLOCK_N + s * STORE_BLOCK_N);
// Store into shared memory

View File

@@ -143,7 +143,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
return {block_q_idx + gridDim.x, q_iter_idx + 1};
};
uint32_t seq_k_start[BLOCK_Q];
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
uint32_t start = cute::numeric_limits<uint32_t>::max();
uint32_t end = cute::numeric_limits<uint32_t>::min();
@@ -152,8 +152,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
start = min(start, min(seq_k_start[i], seq_len_kv));
end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv));
end = max(end, min(seq_k_end[i], seq_len_kv));
}
start = start / 4 * 4;
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
@@ -278,9 +279,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
const auto& v_offset = lane_idx;
// Preload weights
constexpr uint32_t kNumWeightsInReg = 52;
constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads);
float weights[BLOCK_Q][kNumWeightsInReg];
DG_STATIC_ASSERT(kNumWeightsInReg <= kNumHeads and kNumWeightsInReg % 4 == 0, "Invalid kNumWeightsInReg");
DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
while (block_q_idx < num_q_blocks) {
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
@@ -337,7 +338,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
float* accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
@@ -367,14 +368,14 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
sum_1 = transform_smem(j + 2, sum_1);
}
float result = sum_0.x + sum_0.y + sum_1.x + sum_1.y;
result *= scale_kv;
auto sum = __fadd2_rn(sum_0, sum_1);
float result = scale_kv * (sum.x + sum.y);
// Store into the global memory
// NOTES: we have redundant writes here, consider more carefully
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
if constexpr (kIsCompressedLogits) {
if (kv_offset + v_offset >= seq_k_start[i])
if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i])
logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result;
} else {
logits[q_idx * stride_logits + kv_offset + v_offset] = result;

View File

@@ -22,8 +22,9 @@ template <uint32_t kNextN, uint32_t kNumHeads,
bool kIsContextLens2D,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads>
__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads + 128, 1)
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
const uint64_t logits_stride, const uint64_t block_table_stride,
const uint32_t* context_lens, float* logits,
@@ -40,9 +41,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
const auto& lane_idx = get_lane_idx();
// Prefetch TMA descriptors
static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_q);
cute::prefetch_tma_descriptor(&tensor_map_kv);
@@ -54,78 +53,58 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Shared memory configs
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);
static constexpr uint32_t SMEM_UMMA_SIZE = kNumMathWarpGroups * 2 * 8 + static_cast<uint32_t>(sizeof(uint32_t));
// Align to swizzling alignment bytes
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
// Q data and barriers on shared memory
// Q and KV data on shared memory
auto smem_q = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
});
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });
// Separate math warpgroups and tma load warps into KV groups
// Each math warpgroup corresponds to a tma load warp
const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);
// Per group KV data and barriers on shared memory
const auto& smem_kv_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_kv_offset + SMEM_KV_SIZE_PER_STAGE * i);
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i);
});
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + smem_kv_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages;
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
});
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });
// UMMA barriers and TMEM pointer on shared memory
auto umma_barrier_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kNumMathWarpGroups);
// Barriers and TMEM pointer on shared memory
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; });
auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups;
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
const bool& is_math_warp = (warp_idx < (kNumMathThreads / 32)); // 0 16
const bool& is_tma_load_warp = (warp_idx >= (kNumMathThreads / 32) and warp_idx < (kNumMathThreads / 32 + 4)); // 16 20
const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 4)); // 20
const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4);
const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4);
const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1);
// Initialize barriers
if (is_tma_load_warp and cute::elect_one_sync()) {
if (kv_group_idx == 0) {
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
empty_q_barriers[i]->init(kNumMathThreads);
}
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
empty_q_barriers[i]->init(kNumMathThreads);
}
if (kv_group_idx < kNumMathWarpGroups) {
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
full_kv_barriers[i]->init(1);
empty_kv_barriers[i]->init(128);
}
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
full_kv_barriers[i]->init(1);
empty_kv_barriers[i]->init(kNumMathThreads);
}
cutlass::arch::fence_barrier_init();
}
@@ -144,12 +123,13 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
__syncthreads();
// Register reconfigurations
constexpr uint32_t kNumSpecializedRegisters = 32;
constexpr uint32_t kNumMathRegisters = 104;
constexpr uint32_t kNumSpecializedRegisters = 40;
constexpr uint32_t kNumMathRegisters = 232;
// Scheduler
auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups>(batch_size, blockIdx.x, context_lens, schedule_meta);
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit>(batch_size, blockIdx.x, context_lens, schedule_meta);
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
// Q and KV pipeline
const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
@@ -161,19 +141,18 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
// UMMA settings
// Construct instruction with layout F
constexpr uint32_t UMMA_M = 64;
// Construct instruction with layout D
constexpr uint32_t UMMA_M = 128;
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
constexpr uint32_t UMMA_N = kNextN * kNumHeads;
DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
if (is_tma_load_warp) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
if (kv_group_idx >= kNumMathWarpGroups)
return;
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
if (kv_group_idx == 0 and cute::elect_one_sync()) {
if (cute::elect_one_sync()) {
tma_copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
tma_copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
@@ -199,6 +178,14 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
kv_idx = next_kv_idx;
num_kv = next_num_kv;
// Read KV block index
// TODO: deal with `-1`?
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0);
}
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
// Wait Q consumer release and issue TMA Q
if (prefetch_q) {
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
@@ -206,25 +193,26 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
issue_tma_q(q_stage_idx, q_idx + 1);
}
// Read KV block index
// TODO: deal with `-1`?
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ?
__ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0);
}
const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);
int kv_block_idx[kNumBlocksPerSplit];
#pragma unroll
for (int i = 0; i < kNumBlocksPerSplit; ++ i)
kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
kv_block_idx_ptr += kNumBlocksPerSplit;
// Wait KV consumer release
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
// Issue TMA KV
if (cute::elect_one_sync()) {
tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx);
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
#pragma unroll
for (int i = 0; i < kNumBlocksPerSplit; ++ i) {
tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i,
0, 0, 1, kv_block_idx[i]);
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx] + BLOCK_KV * i,
0, kv_block_idx[i]);
}
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
}
@@ -245,32 +233,26 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
uint32_t q_idx = batch_size, kv_idx;
uint32_t next_q_idx, next_kv_idx, next_num_kv;
uint32_t q_stage_idx, q_phase;
uint32_t umma_phase = 0;
auto smem_kv = PatternVisitor([&](const uint32_t& stage_idx) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_PIPE_SIZE + SMEM_KV_SIZE_PER_STAGE * stage_idx);
});
uint32_t umma_phase = 1;
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
if (q_idx != next_q_idx) {
if (q_idx != next_q_idx)
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
}
q_idx = next_q_idx;
kv_idx = next_kv_idx;
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size");
DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim");
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
empty_umma_barriers[i]->wait(umma_phase & 1);
empty_umma_barriers[i]->wait(umma_phase);
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
smem_kv[kv_stage_idx] + i * SMEM_KV_PIPE_SIZE, 0, k * UMMA_K);
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
smem_q[q_stage_idx], 0, k * UMMA_K);
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
@@ -285,10 +267,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Offsets
const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
float weights[kNextN][kNumHeads / 4];
const auto& sub_warp_offset = (warp_idx % 4) * 16;
const auto& v_0_offset = lane_idx / 4 + 0;
const auto& v_1_offset = lane_idx / 4 + 8;
const uint32_t thread_idx = threadIdx.x;
// Weights
constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? kNumHeads : cute::min(48, kNumHeads));
float weights[kNextN][kNumWeightsInReg];
DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
uint32_t q_idx = batch_size, kv_idx;
@@ -310,9 +294,8 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Read weights
#pragma unroll
for (uint32_t i = 0; i < kNextN; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
for (uint32_t j = 0; j < kNumWeightsInReg; ++ j)
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
}
}
@@ -321,75 +304,80 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
kv_idx = next_kv_idx;
// Calculate KV offset in advance
auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV;
// Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
// Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]`
// Wait TMA KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Read per-KV scales
auto scale_kv = make_float2(ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset),
ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset));
float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx);
empty_umma_barriers[warpgroup_idx]->arrive();
// Wait UMMA arrival
full_umma_barriers[warpgroup_idx]->wait(umma_phase & 1);
full_umma_barriers[warpgroup_idx]->wait(umma_phase);
umma_phase ^= 1;
// Release KV empty
empty_kv_barriers[kv_stage_idx]->arrive();
// Reduce over the head dim and store
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN;
uint32_t shifted_accum[kNumLDTMElems];
DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM");
auto tmem_load = [&](auto... Is) {
if constexpr (kNumLDTMElems == 32) {
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 64) {
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 128) {
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
}
};
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
cutlass::arch::fence_view_async_tmem_load();
empty_umma_barriers[warpgroup_idx]->arrive();
#pragma unroll
for (uint32_t i = 0; i < kNextN; ++ i) {
// Load from the tensor memory
constexpr uint32_t kNumLDTMElems = UMMA_M * kNumHeads / 128;
uint32_t shifted_accum[kNumLDTMElems];
DG_STATIC_ASSERT(kNumLDTMElems == 16 or kNumLDTMElems == 32 or kNumLDTMElems == 64, "Invalid LDTM");
auto tmem_load = [&](auto... Is) {
if constexpr (kNumLDTMElems == 16) {
cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 32) {
cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 64) {
cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
}
};
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
cutlass::arch::fence_view_async_tmem_load();
auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
// Transform
const auto& transform_2 = [&](const uint32_t& j, const uint32_t& k, const float2& sum) {
auto a = make_float2(fmaxf(*reinterpret_cast<float*>(&shifted_accum[j * 4 + k]), 0),
fmaxf(*reinterpret_cast<float*>(&shifted_accum[j * 4 + k + 2]), 0));
auto b = make_float2(weights[i][j * 2 + k], weights[i][j * 2 + k]);
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto& transform_reg = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
// Intra-thread reduction
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
#pragma unroll
for (uint32_t j = 0; j < kNumHeads / 8; ++ j) {
sum_0 = transform_2(j, 0, sum_0);
sum_1 = transform_2(j, 1, sum_1);
for (int j = 0; j < kNumWeightsInReg; j += 4) {
sum_0 = transform_reg(j, sum_0);
sum_1 = transform_reg(j + 2, sum_1);
}
auto v = __fmul2_rn(__fadd2_rn(sum_0, sum_1), scale_kv);
// Inter-thread reduction
const auto& transform_smem = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < 2; ++ j) {
const auto& offset = 1u << j;
v.x += __shfl_xor_sync(0xffffffffu, v.x, offset);
v.y += __shfl_xor_sync(0xffffffffu, v.y, offset);
for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
sum_0 = transform_smem(j, sum_0);
sum_1 = transform_smem(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
float result = scale_kv * (sum.x + sum.y);
// Store into the global memory
// NOTES: we have redundant writes here, consider more carefully
logits[kv_offset + i * logits_stride + v_0_offset] = v.x;
logits[kv_offset + i * logits_stride + v_1_offset] = v.y;
logits[kv_offset + i * logits_stride + thread_idx] = result;
}
}
} else {

View File

@@ -0,0 +1,345 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/reduction.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm100;
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
__device__ __forceinline__
uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) {
// Calculate the index of the bank group to be written in the atom
const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase);
// Reshape the atom in another view and swizzle
// - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)`
// - new: `(BLOCK_N * kSwizzleMode / kSwizzleBase / kNumBankGroups, kNumBankGroups)`
constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
constexpr bool kHasShortcut = (kSwizzleMode / kSwizzleBase) == kNumBankGroups;
auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups);
auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups);
col ^= row % (kSwizzleMode / kSwizzleBase);
return row * 128 + col * kSwizzleBase;
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumSplits,
uint32_t kSwizzleCDMode,
uint32_t kNumStages,
uint32_t kNumMMAThreads, uint32_t kNumCastAndReduceThreads>
__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1)
sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_d,
float* sqr_sum) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Configs
constexpr uint32_t kNumCastStages = 2;
constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128);
constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128);
constexpr auto kMajorA = cute::UMMA::Major::K;
constexpr auto kMajorB = cute::UMMA::Major::K;
DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages");
DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N");
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads");
// Utils
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// Share memory sizes
constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float);
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Real tensor memory size and offsets
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_K * kNumCastStages + BLOCK_N>();
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
// Data on shared memory (layout as ordered below)
// Fill D/A/B pointers
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4;
// Fill the tensor memory pointer
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 4 + 1);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
full_cast_barriers[i]->init(kNumCastAndReduceThreads);
empty_barriers[i]->init(1);
empty_cast_barriers[i]->init(1);
}
tmem_full_barrier->init(1);
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
} else if (warp_idx == 2) {
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
__syncthreads();
constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
const uint32_t m_block_idx = block_idx / kNumSplits;
const uint32_t k_split_idx = block_idx % kNumSplits;
const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K;
const uint32_t m_offset = shape_m * k_split_idx;
const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
// Dispatch warps into different roles
if (warp_idx < kNumMMAThreads / 32) {
// TMA load warp
if (warp_idx == 0 and cute::elect_one_sync()) {
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait consumer release
const auto& stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
// Compute offsets
uint32_t m_idx = m_block_idx * BLOCK_M;
uint32_t k_idx = k_offset + s * BLOCK_K;
// Issue TMAs
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
}
}
// MMA issue warp
if (warp_idx == 1) {
// Make instruction descriptor
constexpr uint32_t UMMA_M = BLOCK_M;
constexpr uint32_t UMMA_N = BLOCK_N;
constexpr uint32_t UMMA_K = 32 / sizeof(float);
constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float);
using umma_t = cute::SM100_MMA_TF32_TS<cutlass::tfloat32_t, cutlass::tfloat32_t, float,
BLOCK_M, BLOCK_N, kMajorA, kMajorB>;
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::tfloat32_t, cutlass::tfloat32_t, float,
UMMA_M, UMMA_N, kMajorA, kMajorB>();
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_SWIZZLED_BK, kSwizzleBMode>(smem_b[0], 0, 0);
const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
// Checks for MMA instructions
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
"Invalid MMA instruction shape");
// Launch MMAs
// We can not unroll this part
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
const auto& cast_stage_idx = s % kNumCastStages;
full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1);
tcgen05_after_thread_sync();
// Issue UMMA
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK;
const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK;
const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK;
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, float>(b_desc_base_lo, offset, in_atom_idx);
umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc);
}
// Commit
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_cast_barriers[cast_stage_idx]));
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
}
// Commit to epilogue threads
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barrier));
}
// TMA checks
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float);
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Only support layout F (M = 64) and D (M = 128)
DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M");
// Wait UMMA arrival
tmem_full_barrier->wait(0);
tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough");
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Source and destination memory address
uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup;
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd) + // Base pointer
warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset
get_swizzled_smem_offset<kSwizzleCDMode>(i, lane_idx); // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16))
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
if constexpr (BLOCK_M == 64)
__syncwarp();
}
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0);
if (warp_idx == 0 and cute::elect_one_sync()) {
if constexpr (kNumSplits == 1) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M);
} else {
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx);
}
cute::tma_store_arrive();
}
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
if (warp_idx == 1)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
} else {
DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M");
DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads");
constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4;
const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32;
// TODO: make even larger block K
DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K");
// Launch reductions
float2 sum[2] = {float2{0, 0}, float2{0, 0}};
#pragma unroll kNumStages
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
// Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b)
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16);
constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup;
const auto& smem_base_ptr = reinterpret_cast<uint8_t*>(smem_a[stage_idx]) + // Base pointer
sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset
// 4 lanes shared a bank group
uint32_t uint32_values[2][kNumLoads];
DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads");
#pragma unroll
for (uint32_t i = 0; i < kNumLoads; i += 2) {
auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset<kSwizzleAMode>(i + lane_idx / 16, lane_idx % 16);
sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0],
uint32_values[0][i + 1], uint32_values[1][i + 1],
smem_ptr);
}
// Wait tensor memory empty
const auto& cast_stage_idx = s % kNumCastStages;
empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1);
// Cast, reduce and store into tensor memory
float2 fp32x2_values[2][kNumLoads];
const auto& upper_view = reinterpret_cast<uint32_t*>(&fp32x2_values[0]);
const auto& lower_view = reinterpret_cast<uint32_t*>(&fp32x2_values[1]);
#pragma unroll
for (uint32_t i = 0; i < kNumLoads; ++ i) {
#pragma unroll
for (uint32_t u = 0; u < 2; ++ u) {
fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast<nv_bfloat162*>(&uint32_values[u][i]));
sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]);
}
// Store upper and lower part at the same time
const auto idx_0 = i * 2, idx_1 = i * 2 + 1;
cute::SM100_TMEM_STORE_16dp256b1x::copy(
upper_view[idx_0], upper_view[idx_1],
lower_view[idx_0], lower_view[idx_1],
cast_stage_idx * BLOCK_K + i * 8);
}
cutlass::arch::fence_view_async_tmem_store();
// Arrive for issuing MMAs
tcgen05_before_thread_sync();
full_cast_barriers[cast_stage_idx]->arrive();
}
// Intra-warp reduction and write back
#pragma unroll
for (uint32_t u = 0; u < 2; ++ u) {
const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y);
const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8;
if (lane_idx % 4 == 0 and m_idx < shape_m)
sqr_sum[m_offset + m_idx] = reduced_sum;
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
#endif
}
} // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -350,7 +350,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
// Use TMA store to write back to global memory
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;

View File

@@ -171,20 +171,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
empty_barriers[stage_idx]->wait(phase ^ 1);
// Issue TMA A
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t k_idx = k_block_idx * BLOCK_K;
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, &full_barrier,
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_a, &full_barrier,
smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
num_tma_multicast_a);
num_tma_multicast_a, batch_idx);
tma_copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier,
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.get_global_idx<kWithGroupOffsetA>(shape_k_scales, 1, k_block_idx),
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::SF_K>(shape_k_scales, 1, k_block_idx),
num_tma_multicast_a);
// Issue TMA B
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, &full_barrier,
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_b, &full_barrier,
smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
num_tma_multicast_b);
num_tma_multicast_b, batch_idx);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
}
}
@@ -222,7 +225,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto previous_group_offset = scheduler.get_global_idx<true>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
auto previous_group_offset = scheduler.template get_global_idx<true, IndexType::SF_K>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales;
const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1;
auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb;
@@ -413,9 +416,14 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
epilogue_type_t::apply_index_n<TMA_D_BLOCK_N>(n_block_idx * BLOCK_N + in_block_n_offset),
scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx));
auto n_idx = epilogue_type_t::apply_index_n<TMA_D_BLOCK_N>(n_block_idx * BLOCK_N + in_block_n_offset);
auto m_idx = scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx);
if constexpr (kGemmType == GemmType::Batched) {
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr,
n_idx, m_idx, scheduler.current_group_idx);
} else {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx);
}
cute::tma_store_arrive();
}
__syncwarp();

View File

@@ -127,7 +127,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
return {block_q_idx + gridDim.x, q_iter_idx + 1};
};
uint32_t seq_k_start[BLOCK_Q];
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
uint32_t start = cute::numeric_limits<uint32_t>::max();
uint32_t end = cute::numeric_limits<uint32_t>::min();
@@ -136,8 +136,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
start = min(start, min(seq_k_start[i], seq_len_kv));
end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv));
end = max(end, min(seq_k_end[i], seq_len_kv));
}
start = start / 4 * 4;
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
@@ -304,9 +305,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
// NOTES: we have redundant writes here, consider more carefully
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
if constexpr (kIsCompressedLogits) {
if (kv_offset + v_0_offset >= seq_k_start[i])
if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i])
logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0;
if (kv_offset + v_1_offset >= seq_k_start[i])
if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i])
logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1;
} else {
logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0;

View File

@@ -58,7 +58,7 @@ void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t ne
}
template <uint32_t kNextN, bool kIsContextLens2D,
uint32_t BLOCK_KV, uint32_t kNumMathWarpGroups>
uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit>
struct PagedMQALogitsScheduler {
uint32_t batch_size;
const uint32_t* context_lens;
@@ -79,8 +79,8 @@ struct PagedMQALogitsScheduler {
const auto& current_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx);
const auto& end_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx + 1);
current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumMathWarpGroups;
end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumMathWarpGroups;
current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit;
end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit;
current_num_kv = get_num_kv(current_q_idx);
}
@@ -93,7 +93,7 @@ struct PagedMQALogitsScheduler {
if (q_idx == end_q_idx and kv_idx == end_kv_idx)
return false;
current_kv_idx += kNumMathWarpGroups;
current_kv_idx += kNumBlocksPerSplit;
if (current_kv_idx >= current_num_kv) {
++ current_q_idx;
current_kv_idx = 0;

View File

@@ -0,0 +1,287 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <deep_gemm/common/reduction.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
__device__ __forceinline__
uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) {
constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase;
const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange;
constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups;
auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups);
auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups);
col ^= row % kGroupsInSwizzleRange;
return (row * kNumBankGroups + col) % kGroupsInSwizzleRange;
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumSplits,
uint32_t kSwizzleCDMode,
uint32_t kNumStages,
uint32_t kNumMathThreads, uint32_t kNumTMAThreads>
__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1)
sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_d,
float* sqr_sum) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// kSwizzleAMode and kSwizzleBMode must be 128 for now
constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128);
constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128);
DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K");
DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode");
DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode");
DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N");
DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads");
// Utils
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// Share memory sizes
constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float);
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
// Data on shared memory (layout as ordered below)
// Fill D/A/B pointers
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(128);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
}
__syncthreads();
constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
const uint32_t m_block_idx = block_idx / kNumSplits;
const uint32_t k_split_idx = block_idx % kNumSplits;
const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K;
const uint32_t m_offset = shape_m * k_split_idx;
const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
constexpr uint32_t kNumTMARegisters = 40;
constexpr uint32_t kNumMathRegisters = 256;
// TMA load warp
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait consumer release
const auto& stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
// Compute offsets
uint32_t m_idx = m_block_idx * BLOCK_M;
uint32_t k_idx = k_offset + s * BLOCK_K;
// Issue TMAs
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
}
for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) {
const auto& stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
}
} else if (warp_idx < kNumMathThreads / 32) {
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M");
DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K");
constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4;
constexpr uint32_t WGMMA_M = 64;
constexpr uint32_t WGMMA_N = BLOCK_N;
constexpr uint32_t WGMMA_K = 8;
using WGMMA = typename TF32MMASelector<WGMMA_N, true>::type;
float accum[WGMMA::kNumAccum] = {0};
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16);
constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup;
float sqr_sum_acc_0 = 0;
float sqr_sum_acc_1 = 0;
#pragma unroll kNumStages < 8 ? kNumStages : kNumStages / 2
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128;
constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K;
float a[kNumRegPerWgmma * kNumWgmmaPerBlockK];
// Assume swizzle A mode is 128
DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode");
// Load BF16 A fragment from shared memory into registers, and transpose to FP32
uint32_t row = warp_idx * 16 + lane_idx / 4;
#pragma unroll
for (uint32_t i = 0; i < kNumLoads; ++ i) {
// Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a
uint32_t bank_group_idx = (row ^ i) % 8;
nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup;
nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup;
uint32_t elem_offset = lane_idx % 4;
nv_bfloat16 a_bf16[kNumRegPerWgmma];
a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset];
a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4];
a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset];
a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4];
auto a_bf16x2_ptr = reinterpret_cast<nv_bfloat162*>(a_bf16);
auto a_float2_ptr = reinterpret_cast<float2*>(a);
float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]);
float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]);
a_float2_ptr[i * 2 + 0] = a_float2_0;
a_float2_ptr[i * 2 + 1] = a_float2_1;
sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x;
sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y;
}
warpgroup_wait<0>();
if (s > 0)
empty_barriers[(s - 1) % kNumStages]->arrive();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float);
constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K;
DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K");
#pragma unroll
for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) {
#pragma unroll
for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) {
auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1);
WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1);
}
}
warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
}
const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0);
const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1);
const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4);
if (lane_idx % 4 == 0) {
if (m_idx < shape_m)
sqr_sum[m_offset + m_idx] = reduced_sum_0;
if (m_idx + 8 < shape_m)
sqr_sum[m_offset + m_idx + 8] = reduced_sum_1;
}
warpgroup_wait<0>();
empty_barriers[(num_total_stages-1) % kNumStages]->arrive();
// Write accum to shared memory
// Every 2 threads (one pair) will write to the same bank group (16 bytes).
// Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d
uint32_t is_odd_pair = lane_idx / 2 % 2;
// Four threads per group; write the data to the same row.
uint32_t row_idx = lane_idx / 4;
// Even/odd index pairs write to the same column, we need to reorder idx:
// group even pair indices consecutively, and likewise for odd ones.
uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx;
auto shifted_smem_ptr = reinterpret_cast<uint8_t*>(smem_cd) +
(warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode + // Row offset, each warp has 16 rows
lane_idx % 2 * 8; // One thread of a pair writes 8 bytes
#pragma unroll
for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) {
// Get the swizzled bank group index (16 bytes per group)
uint32_t bank_group_idx = get_swizzled_bank_group_idx<kSwizzleCDMode>(i + is_odd_pair, reordered_pair_idx);
auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes; // Col offset, 16 bytes per group
// 0/1 write to the same row, 2/3 write to another row
auto values = reinterpret_cast<uint32_t*>(accum + i * 2);
st_shared(smem_ptr, values[0], values[1]);
st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]);
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(128, 1);
// Issue TMA stores
if (warp_idx == 0 and cute::elect_one_sync()) {
if constexpr (kNumSplits == 1) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M);
} else {
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx);
}
cute::tma_store_arrive();
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
} // namespace deep_gemm
#pragma clang diagnostic pop