Add various optimizations and Mega MoE benchmarks (#316)
* Merge with private repo * Add Mega MoE Benchmark * Minor fix * Update --------- Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
@@ -1,11 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
|
||||
#include <deep_gemm/ptx/ld_st.cuh>
|
||||
#include <deep_gemm/layout/sym_buffer.cuh>
|
||||
#include <deep_gemm/layout/mega_moe.cuh>
|
||||
|
||||
namespace deep_gemm::comm {
|
||||
|
||||
CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() {
|
||||
// Perform cluster_sync with `barrier.cluster.arrive.relaxed`
|
||||
// This is slightly faster than `cute::cluster_sync` but has weaker memory ordering guarantee
|
||||
cute::cluster_arrive_relaxed();
|
||||
cute::cluster_wait();
|
||||
}
|
||||
|
||||
template <uint32_t kNumSMs, uint32_t kGridSyncIndex = 0, typename sync_scope_t>
|
||||
CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace,
|
||||
const uint32_t& sm_idx, const uint32_t& thread_idx,
|
||||
|
||||
@@ -401,7 +401,8 @@ void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
// Load accumulator from TMEM
|
||||
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
|
||||
tmem_load(cute::Int<kNumHeads>{}, tmem_addr, accum);
|
||||
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
|
||||
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
|
||||
|
||||
// Release TMEM empty
|
||||
if (i == BLOCK_Q - 1) {
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace deep_gemm {
|
||||
|
||||
template <uint32_t kNextN, uint32_t kNumHeads,
|
||||
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
||||
bool kIsContextLens2D,
|
||||
bool kIsContextLens2D, bool kIsVarlen,
|
||||
uint32_t kNumQStages, uint32_t kNumKVStages,
|
||||
uint32_t SPLIT_KV,
|
||||
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
||||
@@ -30,7 +30,8 @@ CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
|
||||
void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
const uint32_t logits_stride, const uint32_t block_table_stride,
|
||||
const uint32_t* context_lens, logits_dtype_t* logits,
|
||||
const uint32_t* block_table, const uint32_t* schedule_meta,
|
||||
const uint32_t* block_table, const uint32_t* indices,
|
||||
const uint32_t* schedule_meta,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
||||
@@ -54,10 +55,10 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sf_kv);
|
||||
}
|
||||
|
||||
// Next-N atom configs
|
||||
static constexpr uint32_t kNextNAtom = (kNextN % 2 == 0) ? 2 : 1;
|
||||
static constexpr uint32_t kNumNextNAtoms = kNextN / kNextNAtom;
|
||||
static constexpr bool kSingleAtom = (kNumNextNAtoms == 1);
|
||||
// For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
|
||||
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
|
||||
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
|
||||
static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
|
||||
|
||||
// UMMA configs
|
||||
static constexpr uint32_t kNumTmemStages = 3;
|
||||
@@ -157,7 +158,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
|
||||
// Scheduler
|
||||
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
|
||||
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
|
||||
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
|
||||
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
|
||||
|
||||
// Make Q, KV and TMEM pipeline
|
||||
@@ -182,7 +183,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
||||
|
||||
if (cute::elect_one_sync()) {
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
|
||||
// Persistently schedule over blocks
|
||||
// Initialize outside valid range to indicate no previous task
|
||||
@@ -196,11 +197,12 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
||||
|
||||
// Issue TMA Q
|
||||
const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx);
|
||||
cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[q_stage_idx]),
|
||||
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
|
||||
smem_q[q_stage_idx], 0, q_atom_idx * kNextNAtom * kNumHeads);
|
||||
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_atom_idx * kNextNAtom);
|
||||
tma::copy<kNumHeads, kNextNAtom, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_atom_idx * kNextNAtom);
|
||||
smem_q[q_stage_idx], 0, q_token_idx * kNumHeads);
|
||||
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx);
|
||||
tma::copy<kNumHeads, kNextNAtom, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx);
|
||||
full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE);
|
||||
}
|
||||
last_q_atom_idx = q_atom_idx;
|
||||
@@ -210,7 +212,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
} else if (warp_idx == kSpecWarpStart + 1) {
|
||||
// TMA warp for loading KV cache
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
|
||||
// Persistently schedule over blocks
|
||||
uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage;
|
||||
@@ -225,10 +227,11 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
// Coalesced load of block table
|
||||
if (kv_block_idx_ptr == 32) {
|
||||
kv_block_idx_ptr = 0;
|
||||
const auto block_table_offset = (q_atom_idx / kNumNextNAtoms) * static_cast<uint64_t>(block_table_stride);
|
||||
const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
|
||||
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
|
||||
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Broadcast KV block indices
|
||||
int kv_block_idx[kNumBlocksPerSplit];
|
||||
@@ -240,7 +243,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
|
||||
// Wait KV consumer release
|
||||
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
|
||||
|
||||
|
||||
// Issue TMA KV
|
||||
if (cute::elect_one_sync()) {
|
||||
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
||||
@@ -260,7 +263,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
} else if (warp_idx == kSpecWarpStart + 2) {
|
||||
// UMMA warp
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
|
||||
|
||||
// UTCCP transposer
|
||||
@@ -371,7 +374,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
} else if (warp_idx < kSpecWarpStart) {
|
||||
// Math warpgroups for reduce
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
|
||||
const auto math_warpgroup_idx = warpgroup_idx;
|
||||
const auto math_thread_idx = warp_idx * 32 + lane_idx;
|
||||
@@ -400,6 +403,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
// Persistently schedule over blocks
|
||||
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
|
||||
uint32_t q_atom_idx, kv_idx, _;
|
||||
bool is_paired_atom = false;
|
||||
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) {
|
||||
if (q_atom_idx != last_q_atom_idx) {
|
||||
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
|
||||
@@ -423,11 +427,16 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
weights[i][j + 3] = raw.w;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this atom pairs two tokens from the same sequence
|
||||
if constexpr (kIsVarlen) {
|
||||
is_paired_atom = (scheduler.get_atom_advance(q_atom_idx, batch_size) == 2);
|
||||
}
|
||||
}
|
||||
last_q_atom_idx = q_atom_idx;
|
||||
|
||||
// Calculate KV offset in advance
|
||||
auto kv_offset = q_atom_idx * kNextNAtom * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx;
|
||||
auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx;
|
||||
|
||||
// Advance pipeline by `kNumMathWarpGroups` steps
|
||||
// Wait UMMA arrival
|
||||
@@ -436,53 +445,58 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
|
||||
ptx::tcgen05_after_thread_sync();
|
||||
|
||||
// Reduce over the head dim and store
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
|
||||
// Load accumulator from TMEM
|
||||
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
|
||||
tmem_load(cute::Int<kNumHeads>{}, tmem_addr, accum);
|
||||
const auto reduce_and_store = [&](auto num_iters_c) {
|
||||
constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
|
||||
|
||||
// Only loop over valid iterations
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumIters; ++ i) {
|
||||
// Load accumulator from TMEM
|
||||
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
|
||||
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
|
||||
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
|
||||
|
||||
// Accumulate weighted ReLU in parallel
|
||||
auto sum_0 = make_float2(0, 0);
|
||||
auto sum_1 = make_float2(0, 0);
|
||||
|
||||
const auto transform = [&](const uint32_t& j, const float2& sum) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
||||
return __ffma2_rn(a, b, sum);
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
||||
sum_0 = transform(j, sum_0);
|
||||
sum_1 = transform(j + 2, sum_1);
|
||||
}
|
||||
|
||||
auto sum = __fadd2_rn(sum_0, sum_1);
|
||||
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
|
||||
|
||||
// Store into the global memory
|
||||
logits[kv_offset + i * static_cast<uint64_t>(logits_stride)] = result;
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Release TMEM empty
|
||||
if (i == kNextNAtom - 1) {
|
||||
ptx::tcgen05_before_thread_sync();
|
||||
empty_tmem_barriers[tmem_stage_idx]->arrive();
|
||||
}
|
||||
ptx::tcgen05_before_thread_sync();
|
||||
empty_tmem_barriers[tmem_stage_idx]->arrive();
|
||||
};
|
||||
|
||||
// Accumulate weighted ReLU in parallel
|
||||
auto sum_0 = make_float2(0, 0);
|
||||
auto sum_1 = make_float2(0, 0);
|
||||
|
||||
const auto transform = [&](const uint32_t& j, const float2& sum) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
||||
return __ffma2_rn(a, b, sum);
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
||||
sum_0 = transform(j, sum_0);
|
||||
sum_1 = transform(j + 2, sum_1);
|
||||
}
|
||||
|
||||
auto sum = __fadd2_rn(sum_0, sum_1);
|
||||
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
|
||||
|
||||
// Store into the global memory
|
||||
const auto dst_offset = kv_offset + i * static_cast<uint64_t>(logits_stride);
|
||||
if constexpr(sizeof(logits_dtype_t) == 2) {
|
||||
// Pack two adjacent bf16 lanes into uint32 for wider store
|
||||
uint16_t my_bits = *reinterpret_cast<const uint16_t*>(&result);
|
||||
uint16_t neighbor_bits = __shfl_down_sync(0xffffffff, my_bits, 1);
|
||||
uint32_t packed;
|
||||
asm volatile("mov.b32 %0, {%1, %2};" : "=r"(packed) : "h"(my_bits), "h"(neighbor_bits));
|
||||
if (lane_idx % 2 == 0)
|
||||
*reinterpret_cast<uint32_t*>(logits + dst_offset) = packed;
|
||||
} else {
|
||||
logits[dst_offset] = result;
|
||||
}
|
||||
// this sync warp prevent the next load tmem from reordering
|
||||
// nvcc may reorder it to overlap with the current tmem load, lead to large register usage
|
||||
__syncwarp();
|
||||
if constexpr (kIsVarlen) {
|
||||
if (is_paired_atom)
|
||||
reduce_and_store(cute::Int<kNextNAtom>{});
|
||||
else
|
||||
reduce_and_store(cute::Int<1>{});
|
||||
} else if constexpr (kPadOddN) {
|
||||
if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
|
||||
reduce_and_store(cute::Int<1>{});
|
||||
else
|
||||
reduce_and_store(cute::Int<kNextNAtom>{});
|
||||
} else {
|
||||
reduce_and_store(cute::Int<kNextNAtom>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ template <
|
||||
>
|
||||
CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void
|
||||
sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
int* cumulative_local_expert_recv_stats,
|
||||
const uint32_t num_tokens,
|
||||
const __grid_constant__ layout::SymBuffer<kNumRanks> sym_buffer,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts,
|
||||
@@ -91,7 +92,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
|
||||
// Workspaces
|
||||
const auto workspace = layout::Workspace(
|
||||
sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk, BLOCK_M);
|
||||
sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk);
|
||||
|
||||
// Token and buffer layouts
|
||||
constexpr auto fp8_token_layout = layout::Data(kHidden);
|
||||
@@ -170,7 +171,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
constexpr uint32_t UMMA_K = 32;
|
||||
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A
|
||||
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N;
|
||||
DG_STATIC_ASSERT(BLOCK_M % 32 == 0, "Invalid block M");
|
||||
DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M");
|
||||
DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N");
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
|
||||
|
||||
@@ -269,7 +270,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2);
|
||||
|
||||
// A cluster sync is essential for 2CTA tensor memory allocation
|
||||
cute::cluster_sync();
|
||||
comm::cluster_sync_with_relaxed_arrive();
|
||||
|
||||
// Initialization
|
||||
if (warp_idx == 0) {
|
||||
@@ -307,7 +308,9 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
// Allocate tensor memory
|
||||
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
cute::cluster_sync();
|
||||
// NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`,
|
||||
// and `barrier.cluster.wait.aligned` is by default `.acquire`
|
||||
comm::cluster_sync_with_relaxed_arrive();
|
||||
|
||||
// Task scheduler
|
||||
auto scheduler = sched::MegaMoEScheduler<
|
||||
@@ -599,7 +602,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Clean workspace for the next usage
|
||||
// Clean workspace for the next usage, and also do cumulative stats
|
||||
// NOTES: it is overlapped with combine reduction epilogue
|
||||
ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
|
||||
|
||||
@@ -623,19 +626,27 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
// Wait read count ready
|
||||
ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
|
||||
|
||||
// Clean expert token count
|
||||
if (thread_idx == 0)
|
||||
// Clean expert token count, and add cumulative results
|
||||
DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps");
|
||||
if (warp_idx == 0) {
|
||||
*workspace.get_expert_recv_count_sum_ptr(i) = 0;
|
||||
} else if (warp_idx == 1) {
|
||||
if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr)
|
||||
ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast<int>(num_recv_tokens));
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Clean per-rank token count
|
||||
for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads)
|
||||
*workspace.get_expert_recv_count_ptr(j, i) = 0;
|
||||
__syncwarp();
|
||||
|
||||
// Clean L1 and L2 arrival stuffs
|
||||
for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) {
|
||||
*workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0;
|
||||
*workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0;
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -672,23 +683,22 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx);
|
||||
const auto expected = scheduler.template get_valid_m<false>();
|
||||
while (ptx::ld_acq(ptr) != expected);
|
||||
} else {
|
||||
// The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2x L1 blocks' arrival
|
||||
// NOTES: Originally we wait blocks on-demand to overlap L1 calculation
|
||||
// with L2, but this optimization is negative when `num_experts_per_wave`
|
||||
// guarantees L1's completion when L2 starts. So we remove it.
|
||||
// In the future, if `num_experts_per_wave` is not large enough
|
||||
// due to small `num_experts_per_rank`, we may need to add it back or add a switch
|
||||
DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes");
|
||||
const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx);
|
||||
// NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts
|
||||
// to avoid undefined behavior when `num_k_blocks == 32`
|
||||
const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1;
|
||||
while (ptx::ld_acq_gpu(ptr) != expected);
|
||||
}
|
||||
|
||||
uint64_t cached_l2_arrival_mask = 0;
|
||||
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) {
|
||||
// Wait current K block arrival
|
||||
if (block_phase == sched::BlockPhase::Linear2) {
|
||||
// The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2 L1 blocks' arrival
|
||||
DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes");
|
||||
const uint64_t needed = 3ull << (k_block_idx * 2);
|
||||
if ((cached_l2_arrival_mask & needed) != needed) {
|
||||
const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx);
|
||||
do {
|
||||
cached_l2_arrival_mask = ptx::ld_acq_gpu(ptr);
|
||||
} while ((cached_l2_arrival_mask & needed) != needed);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait consumer release
|
||||
empty_barriers[stage_idx]->wait(phase ^ 1);
|
||||
|
||||
@@ -953,8 +963,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
|
||||
// Load weights from global into register cache per 32 tokens
|
||||
DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size");
|
||||
DG_STATIC_ASSERT(WG_BLOCK_M % 32 == 0, "Invalid block size");
|
||||
if ((j * ATOM_M) % 32 == 0) {
|
||||
if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) {
|
||||
stored_cached_weight = *l1_topk_weights_buffer
|
||||
.get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx)
|
||||
.get_base_ptr<float>();
|
||||
@@ -1060,19 +1069,26 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
|
||||
// Only one warp per pair writes (both hold the same SF after cross-warp reduce)
|
||||
// Each lane < 4 holds SF for 2 rows (sf.x and sf.y)
|
||||
if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) {
|
||||
// TODO: I believe the expression can be optimized
|
||||
const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M
|
||||
+ epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2;
|
||||
const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2;
|
||||
const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4;
|
||||
const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t);
|
||||
const auto sf_base_ptr = l2_sf_buffer.get_base_ptr<uint8_t>();
|
||||
// NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4
|
||||
// NOTES: originally there was:
|
||||
// - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2
|
||||
// - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)`
|
||||
// We find out that
|
||||
// 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside
|
||||
// 2. `lane_idx * 2` controls the lowest 3 bit of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside
|
||||
// This reduce the number of computation instructions.
|
||||
const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M;
|
||||
__builtin_assume(token_base_idx < BLOCK_M);
|
||||
const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M
|
||||
+ transform_sf_token_idx(token_idx_in_expert);
|
||||
sf_base_ptr[k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx] =
|
||||
+ m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4;
|
||||
const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx;
|
||||
sf_base_ptr[sf_addr] =
|
||||
(*reinterpret_cast<const uint32_t*>(&sf.x) >> 23);
|
||||
sf_base_ptr[k_uint_idx * mn_stride + (sf_pool_token_idx + 4) * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx] =
|
||||
sf_base_ptr[sf_addr + 4 * static_cast<uint32_t>(sizeof(uint32_t))] =
|
||||
(*reinterpret_cast<const uint32_t*>(&sf.y) >> 23);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace deep_gemm {
|
||||
|
||||
template <uint32_t kNextN, uint32_t kNumHeads,
|
||||
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
||||
bool kIsContextLens2D,
|
||||
bool kIsContextLens2D, bool kIsVarlen,
|
||||
uint32_t kNumQStages, uint32_t kNumKVStages,
|
||||
uint32_t SPLIT_KV,
|
||||
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
|
||||
@@ -30,7 +30,8 @@ CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
|
||||
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
const uint32_t logits_stride, const uint32_t block_table_stride,
|
||||
const uint32_t* context_lens, logits_dtype_t* logits,
|
||||
const uint32_t* block_table, const uint32_t* schedule_meta,
|
||||
const uint32_t* block_table, const uint32_t* indices,
|
||||
const uint32_t* schedule_meta,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
||||
@@ -53,10 +54,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
||||
}
|
||||
|
||||
// Next-N atom configs
|
||||
static constexpr uint32_t kNextNAtom = (kNextN % 2 == 0) ? 2 : 1;
|
||||
static constexpr uint32_t kNumNextNAtoms = kNextN / kNextNAtom;
|
||||
static constexpr bool kSingleAtom = (kNumNextNAtoms == 1);
|
||||
// For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
|
||||
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
|
||||
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
|
||||
static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
|
||||
|
||||
// Shared memory configs
|
||||
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
||||
@@ -136,7 +137,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
|
||||
// Scheduler
|
||||
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
|
||||
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
|
||||
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
|
||||
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
|
||||
|
||||
// Q and KV pipeline
|
||||
@@ -157,13 +158,14 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
if (warp_idx == kSpecWarpStart) {
|
||||
// TMA warp for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
||||
|
||||
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_atom_idx) {
|
||||
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) {
|
||||
if (cute::elect_one_sync()) {
|
||||
tma::copy<kHeadDim, kNextNAtom * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_atom_idx * kNextNAtom * kNumHeads);
|
||||
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_atom_idx * kNextNAtom);
|
||||
const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx);
|
||||
tma::copy<kHeadDim, kNextNAtom * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads);
|
||||
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx);
|
||||
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
||||
}
|
||||
};
|
||||
@@ -182,7 +184,8 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
|
||||
while (fetched_next_task) {
|
||||
// Prefetch next Q when (q, atom) changes
|
||||
bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + 1);
|
||||
const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size);
|
||||
bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance);
|
||||
|
||||
if (q_atom_idx != next_q_atom_idx)
|
||||
kv_block_idx_ptr = 32;
|
||||
@@ -195,17 +198,18 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
// TODO(xuzhean): consider -1
|
||||
if (kv_block_idx_ptr == 32) {
|
||||
kv_block_idx_ptr = 0;
|
||||
const auto block_table_offset = (q_atom_idx / kNumNextNAtoms) * static_cast<uint64_t>(block_table_stride);
|
||||
const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
|
||||
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
|
||||
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
|
||||
}
|
||||
__syncwarp();
|
||||
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
|
||||
|
||||
// Wait Q consumer release and issue TMA Q
|
||||
if (prefetch_q) {
|
||||
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
||||
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
||||
issue_tma_q(q_stage_idx, q_atom_idx + 1);
|
||||
issue_tma_q(q_stage_idx, q_atom_idx + next_advance);
|
||||
}
|
||||
|
||||
uint32_t kv_block_idx[kNumBlocksPerSplit];
|
||||
@@ -236,7 +240,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
}
|
||||
} else if (warp_idx == kSpecWarpStart + 1) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
||||
|
||||
// Require full allocation
|
||||
@@ -292,7 +296,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
} else if (warp_idx < kSpecWarpStart) {
|
||||
// Math warpgroups for reduce
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
|
||||
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
|
||||
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
||||
|
||||
// Offsets
|
||||
@@ -321,6 +325,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
|
||||
uint32_t q_stage_idx, q_phase;
|
||||
uint32_t umma_phase = 0;
|
||||
bool is_paired_atom = false;
|
||||
|
||||
while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) {
|
||||
// Q or atom changes
|
||||
@@ -340,6 +345,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
for (uint32_t j = 0; j < kNumHeads; ++ j)
|
||||
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
|
||||
}
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2);
|
||||
}
|
||||
}
|
||||
|
||||
// Get current task indices
|
||||
@@ -347,7 +356,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
kv_idx = next_kv_idx;
|
||||
|
||||
// Calculate KV offset in advance
|
||||
auto kv_offset = q_atom_idx * kNextNAtom * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV;
|
||||
auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV;
|
||||
|
||||
// Wait TMA KV arrival
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
||||
@@ -367,40 +376,56 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
// Reduce over the head dim and store
|
||||
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
|
||||
// Load accumulator from TMEM
|
||||
const auto reduce_and_store = [&](auto num_iters_c) {
|
||||
constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
|
||||
float accum[kNumHeads];
|
||||
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
|
||||
|
||||
// Release TMEM empty
|
||||
if (i == kNextNAtom - 1) {
|
||||
ptx::tcgen05_before_thread_sync();
|
||||
empty_umma_barriers[math_warpgroup_idx]->arrive();
|
||||
}
|
||||
|
||||
// Accumulate weighted ReLU in parallel
|
||||
auto sum_0 = make_float2(0, 0);
|
||||
auto sum_1 = make_float2(0, 0);
|
||||
|
||||
const auto transform = [&](const uint32_t& j, const float2& sum) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
||||
return __ffma2_rn(a, b, sum);
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
||||
sum_0 = transform(j, sum_0);
|
||||
sum_1 = transform(j + 2, sum_1);
|
||||
for (uint32_t i = 0; i < kNumIters; ++ i) {
|
||||
// Load accumulator from TMEM
|
||||
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
|
||||
|
||||
// Accumulate weighted ReLU in parallel
|
||||
auto sum_0 = make_float2(0, 0);
|
||||
auto sum_1 = make_float2(0, 0);
|
||||
|
||||
const auto transform = [&](const uint32_t& j, const float2& sum) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
||||
return __ffma2_rn(a, b, sum);
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads; j += 4) {
|
||||
sum_0 = transform(j, sum_0);
|
||||
sum_1 = transform(j + 2, sum_1);
|
||||
}
|
||||
|
||||
auto sum = __fadd2_rn(sum_0, sum_1);
|
||||
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
|
||||
|
||||
// Store into the global memory
|
||||
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + math_thread_idx] = result;
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
auto sum = __fadd2_rn(sum_0, sum_1);
|
||||
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
|
||||
// Release TMEM empty
|
||||
ptx::tcgen05_before_thread_sync();
|
||||
empty_umma_barriers[math_warpgroup_idx]->arrive();
|
||||
};
|
||||
|
||||
// Store into the global memory
|
||||
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + math_thread_idx] = result;
|
||||
__syncwarp();
|
||||
if constexpr (kIsVarlen) {
|
||||
if (is_paired_atom)
|
||||
reduce_and_store(cute::Int<kNextNAtom>{});
|
||||
else
|
||||
reduce_and_store(cute::Int<1>{});
|
||||
} else if constexpr (kPadOddN) {
|
||||
if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
|
||||
reduce_and_store(cute::Int<1>{});
|
||||
else
|
||||
reduce_and_store(cute::Int<kNextNAtom>{});
|
||||
} else {
|
||||
reduce_and_store(cute::Int<kNextNAtom>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace deep_gemm {
|
||||
|
||||
template <uint32_t kNextN, uint32_t kNumHeads,
|
||||
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
||||
bool kIsContextLens2D,
|
||||
bool kIsContextLens2D, bool kIsVarlen,
|
||||
uint32_t kNumQStages, uint32_t kNumKVStages,
|
||||
uint32_t SPLIT_KV,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
||||
@@ -30,11 +30,14 @@ CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
|
||||
void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
const uint32_t logits_stride, const uint32_t block_table_stride,
|
||||
const uint32_t* context_lens, logits_dtype_t* logits,
|
||||
const uint32_t* block_table, const uint32_t* schedule_meta,
|
||||
const uint32_t* block_table, const uint32_t* indices,
|
||||
const uint32_t* schedule_meta,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
||||
DG_STATIC_ASSERT(not kIsVarlen, "Varlen is not supported for SM90 paged MQA logits");
|
||||
|
||||
// Types
|
||||
using WGMMA = typename mma::sm90::FP8MMASelector<kNextN * kNumHeads>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
@@ -132,8 +135,8 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
cudaGridDependencySynchronize();
|
||||
|
||||
// Scheduler
|
||||
auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups, 1>(
|
||||
blockIdx.x, context_lens, schedule_meta);
|
||||
auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumMathWarpGroups, 1>(
|
||||
blockIdx.x, batch_size, context_lens, schedule_meta, indices);
|
||||
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
|
||||
|
||||
// Q and KV pipeline
|
||||
|
||||
@@ -1,19 +1,27 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/numeric/math.hpp>
|
||||
|
||||
#include <deep_gemm/common/math.cuh>
|
||||
#include <deep_gemm/common/exception.cuh>
|
||||
|
||||
namespace deep_gemm::layout {
|
||||
|
||||
// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding
|
||||
static constexpr int kNumCandidateBlockMs = 7;
|
||||
static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192};
|
||||
static constexpr int kMaxCandidateBlockM = 192;
|
||||
static constexpr int kMinCandidateBlockM = 8;
|
||||
static constexpr int kLCMCandidateBlockM = 384;
|
||||
|
||||
// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding, among all possible BLOCK_M
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk,
|
||||
T num_experts_per_rank, T block_m) {
|
||||
T num_experts_per_rank) {
|
||||
const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank;
|
||||
const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank);
|
||||
return math::constexpr_align(
|
||||
num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (block_m - 1),
|
||||
block_m);
|
||||
num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast<T>(kMaxCandidateBlockM) - 1),
|
||||
static_cast<T>(kLCMCandidateBlockM));
|
||||
}
|
||||
|
||||
// SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M
|
||||
@@ -48,17 +56,14 @@ struct Workspace {
|
||||
const uint32_t& num_ranks,
|
||||
const uint32_t& num_experts,
|
||||
const uint32_t& num_max_tokens_per_rank,
|
||||
const uint32_t& num_topk,
|
||||
const uint32_t& block_m):
|
||||
const uint32_t& num_topk):
|
||||
base(base),
|
||||
num_ranks(num_ranks), num_experts(num_experts),
|
||||
num_max_tokens_per_rank(num_max_tokens_per_rank) {
|
||||
num_experts_per_rank = num_experts / num_ranks;
|
||||
num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank;
|
||||
num_max_pool_tokens = get_num_max_pool_tokens(
|
||||
num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank, block_m);
|
||||
num_max_pool_blocks = num_max_pool_tokens / block_m;
|
||||
DG_UNIFIED_ASSERT(num_max_tokens_per_rank % block_m == 0);
|
||||
num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
|
||||
num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
|
||||
@@ -164,7 +164,7 @@ CUTLASS_DEVICE uint64_t ld_acq_sys(const uint64_t* ptr) {
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void st_relaxed_sys(const uint64_t* ptr, const uint64_t& value) {
|
||||
asm volatile("st.L1::no_allocate.relaxed.sys.u64 [%0], %1;" :: "l"(ptr), "l"(value));
|
||||
asm volatile("st.L1::no_allocate.relaxed.sys.global.u64 [%0], %1;" :: "l"(ptr), "l"(value));
|
||||
}
|
||||
|
||||
/// Atomics
|
||||
@@ -186,7 +186,11 @@ CUTLASS_DEVICE uint32_t atomic_add_rel(const uint32_t* ptr, const uint32_t& valu
|
||||
return ret;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void red_add(const uint32_t* ptr, const uint32_t& value) {
|
||||
CUTLASS_DEVICE void red_add(const int* ptr, const int& value) {
|
||||
asm volatile("red.gpu.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value));
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void red_add(const uint32_t* ptr, const uint32_t& value) {
|
||||
asm volatile("red.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value));
|
||||
}
|
||||
|
||||
|
||||
@@ -6,22 +6,51 @@
|
||||
|
||||
namespace deep_gemm::sched {
|
||||
|
||||
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs>
|
||||
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs, bool kIsVarlen = false>
|
||||
CUTLASS_GLOBAL __launch_bounds__(32, 1)
|
||||
void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d,
|
||||
const uint32_t* context_lens, uint32_t* schedule_metadata) {
|
||||
const uint32_t* context_lens, const uint32_t* indices, uint32_t* schedule_metadata) {
|
||||
DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
|
||||
const uint32_t lane_idx = ptx::get_lane_idx();
|
||||
|
||||
// Wait for primary kernel completion
|
||||
cudaGridDependencySynchronize();
|
||||
|
||||
__shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize];
|
||||
__shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize];
|
||||
__shared__ uint32_t varlen_num_atoms_shared;
|
||||
uint32_t num_items;
|
||||
|
||||
if constexpr (kIsVarlen) {
|
||||
if (lane_idx == 0) {
|
||||
uint32_t t = 0, atom_count = 0;
|
||||
while (t < batch_size) {
|
||||
varlen_atom_token_start[atom_count] = t;
|
||||
const bool is_paired = (t + 1 < batch_size and indices[t] == indices[t + 1]);
|
||||
varlen_atom_context_len[atom_count] = is_paired ? context_lens[t + 1] : context_lens[t];
|
||||
t += is_paired ? 2 : 1;
|
||||
++ atom_count;
|
||||
}
|
||||
varlen_num_atoms_shared = atom_count;
|
||||
}
|
||||
__syncwarp();
|
||||
num_items = varlen_num_atoms_shared;
|
||||
} else {
|
||||
num_items = batch_size;
|
||||
}
|
||||
|
||||
// Compute num_segs and prefix sum
|
||||
uint32_t num_segs[kAlignedBatchSize / 32];
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
|
||||
const uint32_t q_idx = k * 32 + lane_idx;
|
||||
const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
|
||||
const uint32_t context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0);
|
||||
uint32_t context_len;
|
||||
if constexpr (kIsVarlen) {
|
||||
context_len = (q_idx < num_items ? varlen_atom_context_len[q_idx] : 0);
|
||||
} else {
|
||||
const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
|
||||
context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0);
|
||||
}
|
||||
num_segs[k] = math::ceil_div(context_len, SPLIT_KV);
|
||||
}
|
||||
|
||||
@@ -40,44 +69,118 @@ void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t ne
|
||||
sum = __shfl_sync(0xffffffff, x, 31);
|
||||
}
|
||||
|
||||
const uint32_t num_next_n_atoms = next_n / ((next_n % 2 == 0) ? 2 : 1);
|
||||
const uint32_t total = sum * num_next_n_atoms;
|
||||
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
|
||||
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
|
||||
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
|
||||
uint32_t q_idx = 0;
|
||||
while (q_idx < batch_size and prefix_sum[q_idx] * num_next_n_atoms <= seg_starts)
|
||||
++ q_idx;
|
||||
const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms);
|
||||
const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]);
|
||||
const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0;
|
||||
const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0;
|
||||
const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx;
|
||||
__syncwarp();
|
||||
// SM work distribution
|
||||
if constexpr (kIsVarlen) {
|
||||
const uint32_t total = sum;
|
||||
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
|
||||
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
|
||||
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
|
||||
uint32_t lo = 0, hi = num_items;
|
||||
while (lo < hi) {
|
||||
const uint32_t mid = (lo + hi) / 2;
|
||||
const bool pred = prefix_sum[mid] <= seg_starts;
|
||||
lo = pred ? mid + 1 : lo;
|
||||
hi = pred ? hi : mid;
|
||||
}
|
||||
const uint32_t atom_idx = lo;
|
||||
const uint32_t kv_split_idx = (atom_idx == 0 ? seg_starts : seg_starts - prefix_sum[atom_idx - 1]);
|
||||
const uint32_t q_atom_idx = (atom_idx < num_items ? varlen_atom_token_start[atom_idx] : batch_size);
|
||||
__syncwarp();
|
||||
|
||||
schedule_metadata[sm_idx * 2] = q_atom_idx;
|
||||
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
|
||||
schedule_metadata[sm_idx * 2] = q_atom_idx;
|
||||
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
|
||||
}
|
||||
} else {
|
||||
const uint32_t next_n_atom = (next_n >= 2) ? 2 : 1;
|
||||
const uint32_t num_next_n_atoms = math::ceil_div(next_n, next_n_atom);
|
||||
const uint32_t total = sum * num_next_n_atoms;
|
||||
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
|
||||
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
|
||||
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
|
||||
uint32_t lo = 0, hi = batch_size;
|
||||
while (lo < hi) {
|
||||
const uint32_t mid = (lo + hi) / 2;
|
||||
const bool pred = prefix_sum[mid] * num_next_n_atoms <= seg_starts;
|
||||
lo = pred ? mid + 1 : lo;
|
||||
hi = pred ? hi : mid;
|
||||
}
|
||||
const uint32_t q_idx = lo;
|
||||
const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms);
|
||||
const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]);
|
||||
const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0;
|
||||
const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0;
|
||||
const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx;
|
||||
__syncwarp();
|
||||
|
||||
schedule_metadata[sm_idx * 2] = q_atom_idx;
|
||||
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t kNextN, bool kIsContextLens2D,
|
||||
// Conditional storage for varlen indices pointer (EBO: zero cost when unused)
|
||||
template <bool kHasIndices>
|
||||
struct IndicesStorage {
|
||||
const uint32_t* indices;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IndicesStorage<false> {};
|
||||
|
||||
template <uint32_t kNextN, bool kIsContextLens2D, bool kIsVarlen,
|
||||
uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit,
|
||||
uint32_t kNumNextNAtoms>
|
||||
struct PagedMQALogitsScheduler {
|
||||
struct PagedMQALogitsScheduler : IndicesStorage<kIsVarlen> {
|
||||
const uint32_t* context_lens;
|
||||
uint32_t batch_size;
|
||||
|
||||
uint32_t current_q_atom_idx, current_kv_idx;
|
||||
uint32_t end_q_atom_idx, end_kv_idx;
|
||||
uint32_t current_num_kv;
|
||||
|
||||
CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const {
|
||||
const uint32_t q_idx = q_atom_idx / kNumNextNAtoms;
|
||||
const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
|
||||
return math::ceil_div(context_lens[lens_idx], BLOCK_KV);
|
||||
CUTLASS_DEVICE static uint32_t atom_to_token_idx(const uint32_t& q_atom_idx) {
|
||||
if constexpr (kIsVarlen) {
|
||||
return q_atom_idx;
|
||||
} else {
|
||||
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
|
||||
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
|
||||
if constexpr (kPadOddN) {
|
||||
return q_atom_idx / kNumNextNAtoms * kNextN + q_atom_idx % kNumNextNAtoms * kNextNAtom;
|
||||
} else {
|
||||
return q_atom_idx * kNextNAtom;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t* context_lens, const uint32_t* schedule_meta) {
|
||||
CUTLASS_DEVICE static uint32_t atom_to_block_table_row(const uint32_t& q_atom_idx) {
|
||||
if constexpr (kIsVarlen) {
|
||||
return q_atom_idx;
|
||||
} else {
|
||||
return q_atom_idx / kNumNextNAtoms;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const {
|
||||
if constexpr (kIsVarlen) {
|
||||
const bool is_paired = (q_atom_idx + 1 < batch_size and
|
||||
this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]);
|
||||
const uint32_t ctx_len = is_paired ? context_lens[q_atom_idx + 1] : context_lens[q_atom_idx];
|
||||
return math::ceil_div(ctx_len, BLOCK_KV);
|
||||
} else {
|
||||
const uint32_t q_idx = q_atom_idx / kNumNextNAtoms;
|
||||
const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
|
||||
return math::ceil_div(context_lens[lens_idx], BLOCK_KV);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t& batch_size,
|
||||
const uint32_t* context_lens,
|
||||
const uint32_t* schedule_meta, const uint32_t* indices) {
|
||||
this->context_lens = context_lens;
|
||||
this->batch_size = batch_size;
|
||||
if constexpr (kIsVarlen) {
|
||||
this->indices = indices;
|
||||
}
|
||||
|
||||
const auto current_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx];
|
||||
const auto end_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx + 1];
|
||||
@@ -87,6 +190,28 @@ struct PagedMQALogitsScheduler {
|
||||
current_num_kv = get_num_kv(current_q_atom_idx);
|
||||
}
|
||||
|
||||
// Advance step in q_atom_idx space when moving to the next atom.
|
||||
// Varlen: 1 or 2 depending on whether consecutive tokens share the same sequence.
|
||||
// Non-varlen: always 1 (one atom unit).
|
||||
CUTLASS_DEVICE uint32_t get_atom_advance(const uint32_t& q_atom_idx, const uint32_t& bound) const {
|
||||
if constexpr (kIsVarlen) {
|
||||
return (q_atom_idx + 1 < bound and this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]) ? 2 : 1;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Whether num_kv should be refreshed after advancing to q_atom_idx.
|
||||
// Varlen: always refresh (each atom may have a different context_len).
|
||||
// Non-varlen: only at atom-group boundaries (atoms within a group share context_len).
|
||||
CUTLASS_DEVICE bool should_refresh_num_kv(const uint32_t& q_atom_idx) const {
|
||||
if constexpr (kIsVarlen) {
|
||||
return true;
|
||||
} else {
|
||||
return q_atom_idx % kNumNextNAtoms == 0;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool fetch_next_task(uint32_t &q_atom_idx, uint32_t &kv_idx, uint32_t &num_kv) {
|
||||
q_atom_idx = current_q_atom_idx;
|
||||
kv_idx = current_kv_idx;
|
||||
@@ -97,9 +222,9 @@ struct PagedMQALogitsScheduler {
|
||||
|
||||
current_kv_idx += kNumBlocksPerSplit;
|
||||
if (current_kv_idx >= current_num_kv) {
|
||||
++ current_q_atom_idx;
|
||||
current_kv_idx = 0;
|
||||
if (current_q_atom_idx % kNumNextNAtoms == 0 and exist_q_atom_idx(current_q_atom_idx)) {
|
||||
current_q_atom_idx += get_atom_advance(current_q_atom_idx, end_q_atom_idx);
|
||||
if (should_refresh_num_kv(current_q_atom_idx) and exist_q_atom_idx(current_q_atom_idx)) {
|
||||
current_num_kv = get_num_kv(current_q_atom_idx);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user