Add various optimizations and Mega MoE benchmarks (#316)

* Merge with private repo

* Add Mega MoE Benchmark

* Minor fix

* Update

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
Zhean Xu
2026-04-24 18:41:37 +08:00
committed by GitHub
parent 7f2a703ed5
commit 891d57b4db
21 changed files with 1276 additions and 372 deletions

View File

@@ -1,11 +1,20 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/layout/sym_buffer.cuh>
#include <deep_gemm/layout/mega_moe.cuh>
namespace deep_gemm::comm {
// Cluster barrier made of a relaxed arrive followed by a wait.
// NOTES: because the arrive side is `.relaxed`, this function by itself does not provide
// the memory ordering of `cute::cluster_sync`; callers are expected to make sure any
// required release-side ordering has already been established before calling it
CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() {
// Perform a cluster sync with `barrier.cluster.arrive.relaxed`
// This is slightly faster than `cute::cluster_sync`, but has a weaker memory ordering guarantee
cute::cluster_arrive_relaxed();
cute::cluster_wait();
}
template <uint32_t kNumSMs, uint32_t kGridSyncIndex = 0, typename sync_scope_t>
CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace,
const uint32_t& sm_idx, const uint32_t& thread_idx,

View File

@@ -401,7 +401,8 @@ void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
// Load accumulator from TMEM
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
tmem_load(cute::Int<kNumHeads>{}, tmem_addr, accum);
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
// Release TMEM empty
if (i == BLOCK_Q - 1) {

View File

@@ -20,7 +20,7 @@ namespace deep_gemm {
template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
bool kIsContextLens2D,
bool kIsContextLens2D, bool kIsVarlen,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
@@ -30,7 +30,8 @@ CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
const uint32_t logits_stride, const uint32_t block_table_stride,
const uint32_t* context_lens, logits_dtype_t* logits,
const uint32_t* block_table, const uint32_t* schedule_meta,
const uint32_t* block_table, const uint32_t* indices,
const uint32_t* schedule_meta,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
@@ -54,10 +55,10 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
cute::prefetch_tma_descriptor(&tensor_map_sf_kv);
}
// Next-N atom configs
static constexpr uint32_t kNextNAtom = (kNextN % 2 == 0) ? 2 : 1;
static constexpr uint32_t kNumNextNAtoms = kNextN / kNextNAtom;
static constexpr bool kSingleAtom = (kNumNextNAtoms == 1);
// For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
// UMMA configs
static constexpr uint32_t kNumTmemStages = 3;
@@ -157,7 +158,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
// Scheduler
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
// Make Q, KV and TMEM pipeline
@@ -182,7 +183,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
if (cute::elect_one_sync()) {
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
// Persistently schedule over blocks
// Initialize outside valid range to indicate no previous task
@@ -196,11 +197,12 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
// Issue TMA Q
const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx);
cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[q_stage_idx]),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_q[q_stage_idx], 0, q_atom_idx * kNextNAtom * kNumHeads);
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_atom_idx * kNextNAtom);
tma::copy<kNumHeads, kNextNAtom, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_atom_idx * kNextNAtom);
smem_q[q_stage_idx], 0, q_token_idx * kNumHeads);
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx);
tma::copy<kNumHeads, kNextNAtom, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx);
full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE);
}
last_q_atom_idx = q_atom_idx;
@@ -210,7 +212,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
} else if (warp_idx == kSpecWarpStart + 1) {
// TMA warp for loading KV cache
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
// Persistently schedule over blocks
uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage;
@@ -225,10 +227,11 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
// Coalesced load of block table
if (kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
const auto block_table_offset = (q_atom_idx / kNumNextNAtoms) * static_cast<uint64_t>(block_table_stride);
const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
}
__syncwarp();
// Broadcast KV block indices
int kv_block_idx[kNumBlocksPerSplit];
@@ -240,7 +243,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
// Wait KV consumer release
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
// Issue TMA KV
if (cute::elect_one_sync()) {
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
@@ -260,7 +263,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
} else if (warp_idx == kSpecWarpStart + 2) {
// UMMA warp
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// UTCCP transposer
@@ -371,7 +374,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
} else if (warp_idx < kSpecWarpStart) {
// Math warpgroups for reduce
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
const auto math_warpgroup_idx = warpgroup_idx;
const auto math_thread_idx = warp_idx * 32 + lane_idx;
@@ -400,6 +403,7 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
// Persistently schedule over blocks
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
uint32_t q_atom_idx, kv_idx, _;
bool is_paired_atom = false;
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) {
if (q_atom_idx != last_q_atom_idx) {
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
@@ -423,11 +427,16 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
weights[i][j + 3] = raw.w;
}
}
// Check if this atom pairs two tokens from the same sequence
if constexpr (kIsVarlen) {
is_paired_atom = (scheduler.get_atom_advance(q_atom_idx, batch_size) == 2);
}
}
last_q_atom_idx = q_atom_idx;
// Calculate KV offset in advance
auto kv_offset = q_atom_idx * kNextNAtom * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx;
auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx;
// Advance pipeline by `kNumMathWarpGroups` steps
// Wait UMMA arrival
@@ -436,53 +445,58 @@ void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
ptx::tcgen05_after_thread_sync();
// Reduce over the head dim and store
#pragma unroll
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
// Load accumulator from TMEM
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
tmem_load(cute::Int<kNumHeads>{}, tmem_addr, accum);
const auto reduce_and_store = [&](auto num_iters_c) {
constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
// Only loop over valid iterations
#pragma unroll
for (uint32_t i = 0; i < kNumIters; ++ i) {
// Load accumulator from TMEM
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
// Store into the global memory
logits[kv_offset + i * static_cast<uint64_t>(logits_stride)] = result;
__syncwarp();
}
// Release TMEM empty
if (i == kNextNAtom - 1) {
ptx::tcgen05_before_thread_sync();
empty_tmem_barriers[tmem_stage_idx]->arrive();
}
ptx::tcgen05_before_thread_sync();
empty_tmem_barriers[tmem_stage_idx]->arrive();
};
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
// Store into the global memory
const auto dst_offset = kv_offset + i * static_cast<uint64_t>(logits_stride);
if constexpr(sizeof(logits_dtype_t) == 2) {
// Pack two adjacent bf16 lanes into uint32 for wider store
uint16_t my_bits = *reinterpret_cast<const uint16_t*>(&result);
uint16_t neighbor_bits = __shfl_down_sync(0xffffffff, my_bits, 1);
uint32_t packed;
asm volatile("mov.b32 %0, {%1, %2};" : "=r"(packed) : "h"(my_bits), "h"(neighbor_bits));
if (lane_idx % 2 == 0)
*reinterpret_cast<uint32_t*>(logits + dst_offset) = packed;
} else {
logits[dst_offset] = result;
}
// This `__syncwarp` prevents the next TMEM load from being reordered:
// NVCC may otherwise overlap it with the current TMEM load, leading to large register usage
__syncwarp();
if constexpr (kIsVarlen) {
if (is_paired_atom)
reduce_and_store(cute::Int<kNextNAtom>{});
else
reduce_and_store(cute::Int<1>{});
} else if constexpr (kPadOddN) {
if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
reduce_and_store(cute::Int<1>{});
else
reduce_and_store(cute::Int<kNextNAtom>{});
} else {
reduce_and_store(cute::Int<kNextNAtom>{});
}
}

View File

@@ -48,6 +48,7 @@ template <
>
CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void
sm100_fp8_fp4_mega_moe_impl(void* y,
int* cumulative_local_expert_recv_stats,
const uint32_t num_tokens,
const __grid_constant__ layout::SymBuffer<kNumRanks> sym_buffer,
const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts,
@@ -91,7 +92,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
// Workspaces
const auto workspace = layout::Workspace(
sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk, BLOCK_M);
sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk);
// Token and buffer layouts
constexpr auto fp8_token_layout = layout::Data(kHidden);
@@ -170,7 +171,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
constexpr uint32_t UMMA_K = 32;
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N;
DG_STATIC_ASSERT(BLOCK_M % 32 == 0, "Invalid block M");
DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M");
DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N");
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
@@ -269,7 +270,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2);
// A cluster sync is essential for 2CTA tensor memory allocation
cute::cluster_sync();
comm::cluster_sync_with_relaxed_arrive();
// Initialization
if (warp_idx == 0) {
@@ -307,7 +308,9 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
// Allocate tensor memory
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
cute::cluster_sync();
// NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`,
// and `barrier.cluster.wait.aligned` is by default `.acquire`
comm::cluster_sync_with_relaxed_arrive();
// Task scheduler
auto scheduler = sched::MegaMoEScheduler<
@@ -599,7 +602,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
__syncwarp();
}
// Clean workspace for the next usage
// Clean workspace for the next usage, and also do cumulative stats
// NOTES: it is overlapped with combine reduction epilogue
ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx);
@@ -623,19 +626,27 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
// Wait read count ready
ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx);
// Clean expert token count
if (thread_idx == 0)
// Clean expert token count, and add cumulative results
DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps");
if (warp_idx == 0) {
*workspace.get_expert_recv_count_sum_ptr(i) = 0;
} else if (warp_idx == 1) {
if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr)
ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast<int>(num_recv_tokens));
__syncwarp();
}
// Clean per-rank token count
for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads)
*workspace.get_expert_recv_count_ptr(j, i) = 0;
__syncwarp();
// Clean L1 and L2 arrival stuffs
for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) {
*workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0;
*workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0;
}
__syncwarp();
}
}
@@ -672,23 +683,22 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx);
const auto expected = scheduler.template get_valid_m<false>();
while (ptx::ld_acq(ptr) != expected);
} else {
// The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2x L1 blocks' arrival
// NOTES: Originally we waited for blocks on demand to overlap the L1 calculation
// with L2, but that optimization is a net loss when `num_experts_per_wave`
// already guarantees L1's completion by the time L2 starts, so we removed it.
// In the future, if `num_experts_per_wave` is not large enough
// due to a small `num_experts_per_rank`, we may need to add it back or add a switch
DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes");
const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx);
// NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts
// to avoid undefined behavior when `num_k_blocks == 32`
const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1;
while (ptx::ld_acq_gpu(ptr) != expected);
}
uint64_t cached_l2_arrival_mask = 0;
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) {
// Wait current K block arrival
if (block_phase == sched::BlockPhase::Linear2) {
// The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2 L1 blocks' arrival
DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes");
const uint64_t needed = 3ull << (k_block_idx * 2);
if ((cached_l2_arrival_mask & needed) != needed) {
const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx);
do {
cached_l2_arrival_mask = ptx::ld_acq_gpu(ptr);
} while ((cached_l2_arrival_mask & needed) != needed);
}
}
// Wait consumer release
empty_barriers[stage_idx]->wait(phase ^ 1);
@@ -953,8 +963,7 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
// Load weights from global into register cache per 32 tokens
DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size");
DG_STATIC_ASSERT(WG_BLOCK_M % 32 == 0, "Invalid block size");
if ((j * ATOM_M) % 32 == 0) {
if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) {
stored_cached_weight = *l1_topk_weights_buffer
.get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx)
.get_base_ptr<float>();
@@ -1060,19 +1069,26 @@ sm100_fp8_fp4_mega_moe_impl(void* y,
// Only one warp per pair writes (both hold the same SF after cross-warp reduce)
// Each lane < 4 holds SF for 2 rows (sf.x and sf.y)
if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) {
// TODO: I believe the expression can be optimized
const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M
+ epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2;
const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2;
const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4;
const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t);
const auto sf_base_ptr = l2_sf_buffer.get_base_ptr<uint8_t>();
// NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4
// NOTES: originally there was:
// - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2`
// - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)`
// We found that:
// 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can move `m_block_idx * BLOCK_M` outside the transform
// 2. `lane_idx * 2` controls the lowest 3 bits of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can move `lane_idx * 2` outside the transform as well
// This reduces the number of computation instructions.
const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M;
__builtin_assume(token_base_idx < BLOCK_M);
const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M
+ transform_sf_token_idx(token_idx_in_expert);
sf_base_ptr[k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx] =
+ m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4;
const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx;
sf_base_ptr[sf_addr] =
(*reinterpret_cast<const uint32_t*>(&sf.x) >> 23);
sf_base_ptr[k_uint_idx * mn_stride + (sf_pool_token_idx + 4) * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx] =
sf_base_ptr[sf_addr + 4 * static_cast<uint32_t>(sizeof(uint32_t))] =
(*reinterpret_cast<const uint32_t*>(&sf.y) >> 23);
}
__syncwarp();

View File

@@ -20,7 +20,7 @@ namespace deep_gemm {
template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
bool kIsContextLens2D,
bool kIsContextLens2D, bool kIsVarlen,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
@@ -30,7 +30,8 @@ CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
const uint32_t logits_stride, const uint32_t block_table_stride,
const uint32_t* context_lens, logits_dtype_t* logits,
const uint32_t* block_table, const uint32_t* schedule_meta,
const uint32_t* block_table, const uint32_t* indices,
const uint32_t* schedule_meta,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
@@ -53,10 +54,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
cute::prefetch_tma_descriptor(&tensor_map_weights);
}
// Next-N atom configs
static constexpr uint32_t kNextNAtom = (kNextN % 2 == 0) ? 2 : 1;
static constexpr uint32_t kNumNextNAtoms = kNextN / kNextNAtom;
static constexpr bool kSingleAtom = (kNumNextNAtoms == 1);
// For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
// Shared memory configs
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
@@ -136,7 +137,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Scheduler
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
// Q and KV pipeline
@@ -157,13 +158,14 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
if (warp_idx == kSpecWarpStart) {
// TMA warp for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_atom_idx) {
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) {
if (cute::elect_one_sync()) {
tma::copy<kHeadDim, kNextNAtom * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_atom_idx * kNextNAtom * kNumHeads);
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_atom_idx * kNextNAtom);
const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx);
tma::copy<kHeadDim, kNextNAtom * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads);
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
}
};
@@ -182,7 +184,8 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
while (fetched_next_task) {
// Prefetch next Q when (q, atom) changes
bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + 1);
const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size);
bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance);
if (q_atom_idx != next_q_atom_idx)
kv_block_idx_ptr = 32;
@@ -195,17 +198,18 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// TODO(xuzhean): consider -1
if (kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
const auto block_table_offset = (q_atom_idx / kNumNextNAtoms) * static_cast<uint64_t>(block_table_stride);
const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
}
__syncwarp();
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
// Wait Q consumer release and issue TMA Q
if (prefetch_q) {
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
issue_tma_q(q_stage_idx, q_atom_idx + 1);
issue_tma_q(q_stage_idx, q_atom_idx + next_advance);
}
uint32_t kv_block_idx[kNumBlocksPerSplit];
@@ -236,7 +240,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
}
} else if (warp_idx == kSpecWarpStart + 1) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
// Require full allocation
@@ -292,7 +296,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
} else if (warp_idx < kSpecWarpStart) {
// Math warpgroups for reduce
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
// Offsets
@@ -321,6 +325,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
uint32_t q_stage_idx, q_phase;
uint32_t umma_phase = 0;
bool is_paired_atom = false;
while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) {
// Q or atom changes
@@ -340,6 +345,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
for (uint32_t j = 0; j < kNumHeads; ++ j)
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
}
if constexpr (kIsVarlen) {
is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2);
}
}
// Get current task indices
@@ -347,7 +356,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
kv_idx = next_kv_idx;
// Calculate KV offset in advance
auto kv_offset = q_atom_idx * kNextNAtom * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV;
auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV;
// Wait TMA KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
@@ -367,40 +376,56 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Reduce over the head dim and store
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
#pragma unroll
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
// Load accumulator from TMEM
const auto reduce_and_store = [&](auto num_iters_c) {
constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
float accum[kNumHeads];
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
// Release TMEM empty
if (i == kNextNAtom - 1) {
ptx::tcgen05_before_thread_sync();
empty_umma_barriers[math_warpgroup_idx]->arrive();
}
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
for (uint32_t i = 0; i < kNumIters; ++ i) {
// Load accumulator from TMEM
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
// Store into the global memory
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + math_thread_idx] = result;
__syncwarp();
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
// Release TMEM empty
ptx::tcgen05_before_thread_sync();
empty_umma_barriers[math_warpgroup_idx]->arrive();
};
// Store into the global memory
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + math_thread_idx] = result;
__syncwarp();
if constexpr (kIsVarlen) {
if (is_paired_atom)
reduce_and_store(cute::Int<kNextNAtom>{});
else
reduce_and_store(cute::Int<1>{});
} else if constexpr (kPadOddN) {
if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
reduce_and_store(cute::Int<1>{});
else
reduce_and_store(cute::Int<kNextNAtom>{});
} else {
reduce_and_store(cute::Int<kNextNAtom>{});
}
}

View File

@@ -21,7 +21,7 @@ namespace deep_gemm {
template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
bool kIsContextLens2D,
bool kIsContextLens2D, bool kIsVarlen,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
@@ -30,11 +30,14 @@ CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
const uint32_t logits_stride, const uint32_t block_table_stride,
const uint32_t* context_lens, logits_dtype_t* logits,
const uint32_t* block_table, const uint32_t* schedule_meta,
const uint32_t* block_table, const uint32_t* indices,
const uint32_t* schedule_meta,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
DG_STATIC_ASSERT(not kIsVarlen, "Varlen is not supported for SM90 paged MQA logits");
// Types
using WGMMA = typename mma::sm90::FP8MMASelector<kNextN * kNumHeads>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
@@ -132,8 +135,8 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
cudaGridDependencySynchronize();
// Scheduler
auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups, 1>(
blockIdx.x, context_lens, schedule_meta);
auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumMathWarpGroups, 1>(
blockIdx.x, batch_size, context_lens, schedule_meta, indices);
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
// Q and KV pipeline

View File

@@ -1,19 +1,27 @@
#pragma once
#include <cute/numeric/math.hpp>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::layout {
// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding
static constexpr int kNumCandidateBlockMs = 7;
static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192};
static constexpr int kMaxCandidateBlockM = 192;
static constexpr int kMinCandidateBlockM = 8;
static constexpr int kLCMCandidateBlockM = 384;
// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding, among all possible BLOCK_M
template <typename T>
CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk,
T num_experts_per_rank, T block_m) {
T num_experts_per_rank) {
const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank;
const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank);
return math::constexpr_align(
num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (block_m - 1),
block_m);
num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast<T>(kMaxCandidateBlockM) - 1),
static_cast<T>(kLCMCandidateBlockM));
}
// SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M
@@ -48,17 +56,14 @@ struct Workspace {
const uint32_t& num_ranks,
const uint32_t& num_experts,
const uint32_t& num_max_tokens_per_rank,
const uint32_t& num_topk,
const uint32_t& block_m):
const uint32_t& num_topk):
base(base),
num_ranks(num_ranks), num_experts(num_experts),
num_max_tokens_per_rank(num_max_tokens_per_rank) {
num_experts_per_rank = num_experts / num_ranks;
num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank;
num_max_pool_tokens = get_num_max_pool_tokens(
num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank, block_m);
num_max_pool_blocks = num_max_pool_tokens / block_m;
DG_UNIFIED_ASSERT(num_max_tokens_per_rank % block_m == 0);
num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM;
}
CUTLASS_HOST_DEVICE

View File

@@ -164,7 +164,7 @@ CUTLASS_DEVICE uint64_t ld_acq_sys(const uint64_t* ptr) {
}
CUTLASS_DEVICE void st_relaxed_sys(const uint64_t* ptr, const uint64_t& value) {
asm volatile("st.L1::no_allocate.relaxed.sys.u64 [%0], %1;" :: "l"(ptr), "l"(value));
asm volatile("st.L1::no_allocate.relaxed.sys.global.u64 [%0], %1;" :: "l"(ptr), "l"(value));
}
/// Atomics
@@ -186,7 +186,11 @@ CUTLASS_DEVICE uint32_t atomic_add_rel(const uint32_t* ptr, const uint32_t& valu
return ret;
}
__forceinline__ __device__ void red_add(const uint32_t* ptr, const uint32_t& value) {
// Signed 32-bit reduction add of `value` into `*ptr`, issued as PTX `red.gpu.global.add.s32`.
// The function returns nothing; the result of the addition is not read back.
// NOTE(review): `ptr` is `const`-qualified although the instruction writes through it, and the
// asm statement declares no "memory" clobber — confirm callers do not rely on compiler-level ordering.
CUTLASS_DEVICE void red_add(const int* ptr, const int& value) {
asm volatile("red.gpu.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value));
}
// Unsigned 32-bit reduction add of `value` into `*ptr`, issued as PTX `red.gpu.global.add.u32`.
// The function returns nothing; the result of the addition is not read back.
// NOTE(review): `ptr` is `const`-qualified although the instruction writes through it, and the
// asm statement declares no "memory" clobber — confirm callers do not rely on compiler-level ordering.
CUTLASS_DEVICE void red_add(const uint32_t* ptr, const uint32_t& value) {
asm volatile("red.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value));
}

View File

@@ -6,22 +6,51 @@
namespace deep_gemm::sched {
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs>
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs, bool kIsVarlen = false>
CUTLASS_GLOBAL __launch_bounds__(32, 1)
void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d,
const uint32_t* context_lens, uint32_t* schedule_metadata) {
const uint32_t* context_lens, const uint32_t* indices, uint32_t* schedule_metadata) {
DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
const uint32_t lane_idx = ptx::get_lane_idx();
// Wait for primary kernel completion
cudaGridDependencySynchronize();
__shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize];
__shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize];
__shared__ uint32_t varlen_num_atoms_shared;
uint32_t num_items;
if constexpr (kIsVarlen) {
if (lane_idx == 0) {
uint32_t t = 0, atom_count = 0;
while (t < batch_size) {
varlen_atom_token_start[atom_count] = t;
const bool is_paired = (t + 1 < batch_size and indices[t] == indices[t + 1]);
varlen_atom_context_len[atom_count] = is_paired ? context_lens[t + 1] : context_lens[t];
t += is_paired ? 2 : 1;
++ atom_count;
}
varlen_num_atoms_shared = atom_count;
}
__syncwarp();
num_items = varlen_num_atoms_shared;
} else {
num_items = batch_size;
}
// Compute num_segs and prefix sum
uint32_t num_segs[kAlignedBatchSize / 32];
#pragma unroll
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
const uint32_t q_idx = k * 32 + lane_idx;
const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
const uint32_t context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0);
uint32_t context_len;
if constexpr (kIsVarlen) {
context_len = (q_idx < num_items ? varlen_atom_context_len[q_idx] : 0);
} else {
const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0);
}
num_segs[k] = math::ceil_div(context_len, SPLIT_KV);
}
@@ -40,44 +69,118 @@ void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t ne
sum = __shfl_sync(0xffffffff, x, 31);
}
const uint32_t num_next_n_atoms = next_n / ((next_n % 2 == 0) ? 2 : 1);
const uint32_t total = sum * num_next_n_atoms;
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
uint32_t q_idx = 0;
while (q_idx < batch_size and prefix_sum[q_idx] * num_next_n_atoms <= seg_starts)
++ q_idx;
const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms);
const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]);
const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0;
const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0;
const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx;
__syncwarp();
// SM work distribution
if constexpr (kIsVarlen) {
const uint32_t total = sum;
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
uint32_t lo = 0, hi = num_items;
while (lo < hi) {
const uint32_t mid = (lo + hi) / 2;
const bool pred = prefix_sum[mid] <= seg_starts;
lo = pred ? mid + 1 : lo;
hi = pred ? hi : mid;
}
const uint32_t atom_idx = lo;
const uint32_t kv_split_idx = (atom_idx == 0 ? seg_starts : seg_starts - prefix_sum[atom_idx - 1]);
const uint32_t q_atom_idx = (atom_idx < num_items ? varlen_atom_token_start[atom_idx] : batch_size);
__syncwarp();
schedule_metadata[sm_idx * 2] = q_atom_idx;
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
schedule_metadata[sm_idx * 2] = q_atom_idx;
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
}
} else {
const uint32_t next_n_atom = (next_n >= 2) ? 2 : 1;
const uint32_t num_next_n_atoms = math::ceil_div(next_n, next_n_atom);
const uint32_t total = sum * num_next_n_atoms;
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
uint32_t lo = 0, hi = batch_size;
while (lo < hi) {
const uint32_t mid = (lo + hi) / 2;
const bool pred = prefix_sum[mid] * num_next_n_atoms <= seg_starts;
lo = pred ? mid + 1 : lo;
hi = pred ? hi : mid;
}
const uint32_t q_idx = lo;
const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms);
const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]);
const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0;
const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0;
const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx;
__syncwarp();
schedule_metadata[sm_idx * 2] = q_atom_idx;
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
}
}
}
template <uint32_t kNextN, bool kIsContextLens2D,
// Conditional storage for varlen indices pointer (EBO: zero cost when unused)
// Primary template: declares the `indices` member, so a type deriving from
// `IndicesStorage<true>` can access it as `this->indices`.
template <bool kHasIndices>
struct IndicesStorage {
// Per-token index array; the scheduler compares adjacent entries for equality
// NOTE(review): presumably a sequence/request id per token — confirm against the caller
const uint32_t* indices;
};
// Specialization for `kHasIndices == false`: an empty struct with no `indices` member,
// so any access to `indices` must sit behind a compile-time `kHasIndices`/varlen check
template <>
struct IndicesStorage<false> {};
template <uint32_t kNextN, bool kIsContextLens2D, bool kIsVarlen,
uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit,
uint32_t kNumNextNAtoms>
struct PagedMQALogitsScheduler {
struct PagedMQALogitsScheduler : IndicesStorage<kIsVarlen> {
const uint32_t* context_lens;
uint32_t batch_size;
uint32_t current_q_atom_idx, current_kv_idx;
uint32_t end_q_atom_idx, end_kv_idx;
uint32_t current_num_kv;
CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const {
const uint32_t q_idx = q_atom_idx / kNumNextNAtoms;
const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
return math::ceil_div(context_lens[lens_idx], BLOCK_KV);
// Map an atom index to the index of the first token it covers.
//  - Varlen: the atom index is returned unchanged.
//  - Non-varlen, odd `kNextN >= 3`: atoms are grouped `kNumNextNAtoms` per group; each group
//    starts at a multiple of `kNextN`, and atoms inside it advance by the per-atom token count.
//  - Non-varlen, otherwise: the atom index is scaled by the per-atom token count.
CUTLASS_DEVICE static uint32_t atom_to_token_idx(const uint32_t& q_atom_idx) {
    if constexpr (kIsVarlen) {
        return q_atom_idx;
    } else {
        // `kIsVarlen` is known to be false in this branch, so only `kNextN` decides the constants
        constexpr uint32_t kTokensPerAtom = (kNextN >= 2) ? 2 : 1;
        constexpr bool kIsOddNextN = (kNextN % 2 == 1) and (kNextN >= 3);
        if constexpr (not kIsOddNextN) {
            return q_atom_idx * kTokensPerAtom;
        } else {
            const uint32_t group_idx = q_atom_idx / kNumNextNAtoms;
            const uint32_t atom_in_group = q_atom_idx % kNumNextNAtoms;
            return group_idx * kNextN + atom_in_group * kTokensPerAtom;
        }
    }
}
CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t* context_lens, const uint32_t* schedule_meta) {
// Map an atom index to a block-table row index.
// Varlen returns the atom index unchanged; non-varlen divides it by `kNumNextNAtoms`.
CUTLASS_DEVICE static uint32_t atom_to_block_table_row(const uint32_t& q_atom_idx) {
    // Dividing by 1 in the varlen case leaves the index untouched
    constexpr uint32_t kAtomsPerRow = kIsVarlen ? 1 : kNumNextNAtoms;
    return q_atom_idx / kAtomsPerRow;
}
// Number of `BLOCK_KV`-sized blocks (rounded up) for the context length that applies to an atom.
//  - Varlen: if the following token exists (`< batch_size`) and has the same `indices` entry,
//    the following token's context length is used; otherwise the atom's own.
//  - Non-varlen: the context length of query `q_atom_idx / kNumNextNAtoms`, taking the last of
//    its `kNextN` entries when `kIsContextLens2D` is set.
CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const {
    uint32_t lens_idx;
    if constexpr (kIsVarlen) {
        const uint32_t next_idx = q_atom_idx + 1;
        // Only touch `indices` when the next token is in range
        bool shares_next = false;
        if (next_idx < batch_size)
            shares_next = (this->indices[q_atom_idx] == this->indices[next_idx]);
        lens_idx = shares_next ? next_idx : q_atom_idx;
    } else {
        const uint32_t q_idx = q_atom_idx / kNumNextNAtoms;
        if constexpr (kIsContextLens2D) {
            lens_idx = q_idx * kNextN + kNextN - 1;
        } else {
            lens_idx = q_idx;
        }
    }
    return math::ceil_div(context_lens[lens_idx], BLOCK_KV);
}
CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t& batch_size,
const uint32_t* context_lens,
const uint32_t* schedule_meta, const uint32_t* indices) {
this->context_lens = context_lens;
this->batch_size = batch_size;
if constexpr (kIsVarlen) {
this->indices = indices;
}
const auto current_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx];
const auto end_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx + 1];
@@ -87,6 +190,28 @@ struct PagedMQALogitsScheduler {
current_num_kv = get_num_kv(current_q_atom_idx);
}
// Step to add to `q_atom_idx` when moving on to the next atom.
// Non-varlen: always 1.
// Varlen: 2 when token `q_atom_idx + 1` lies below `bound` and has the same `indices` entry
// as token `q_atom_idx`; 1 otherwise.
CUTLASS_DEVICE uint32_t get_atom_advance(const uint32_t& q_atom_idx, const uint32_t& bound) const {
    if constexpr (not kIsVarlen) {
        return 1;
    } else {
        const uint32_t next_idx = q_atom_idx + 1;
        // No neighbour inside the bound: the atom is a single token
        if (next_idx >= bound)
            return 1;
        return (this->indices[q_atom_idx] == this->indices[next_idx]) ? 2 : 1;
    }
}
// Tells the caller whether `num_kv` must be recomputed after advancing to `q_atom_idx`.
// Non-varlen: only when `q_atom_idx` is a multiple of `kNumNextNAtoms` (start of an atom group).
// Varlen: unconditionally.
CUTLASS_DEVICE bool should_refresh_num_kv(const uint32_t& q_atom_idx) const {
    if constexpr (not kIsVarlen) {
        return (q_atom_idx % kNumNextNAtoms) == 0;
    } else {
        return true;
    }
}
CUTLASS_DEVICE bool fetch_next_task(uint32_t &q_atom_idx, uint32_t &kv_idx, uint32_t &num_kv) {
q_atom_idx = current_q_atom_idx;
kv_idx = current_kv_idx;
@@ -97,9 +222,9 @@ struct PagedMQALogitsScheduler {
current_kv_idx += kNumBlocksPerSplit;
if (current_kv_idx >= current_num_kv) {
++ current_q_atom_idx;
current_kv_idx = 0;
if (current_q_atom_idx % kNumNextNAtoms == 0 and exist_q_atom_idx(current_q_atom_idx)) {
current_q_atom_idx += get_atom_advance(current_q_atom_idx, end_q_atom_idx);
if (should_refresh_num_kv(current_q_atom_idx) and exist_q_atom_idx(current_q_atom_idx)) {
current_num_kv = get_num_kv(current_q_atom_idx);
}
}