Add SM90 kernels (#200)
This commit is contained in:
317
deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh
Normal file
317
deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh
Normal file
@@ -0,0 +1,317 @@
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/mma_sm90_desc.hpp>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm90_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm90;
|
||||
|
||||
// ReSharper disable once CppNotAllPathsReturnValue
|
||||
template <uint32_t kHeadDim>
|
||||
static constexpr int to_swizzle_cute_type() {
|
||||
DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling");
|
||||
if constexpr (kHeadDim == 32)
|
||||
return static_cast<int>(cute::SM90::GMMA::LayoutType::B32);
|
||||
if constexpr (kHeadDim == 64)
|
||||
return static_cast<int>(cute::SM90::GMMA::LayoutType::B64);
|
||||
if constexpr (kHeadDim == 128)
|
||||
return static_cast<int>(cute::SM90::GMMA::LayoutType::B128);
|
||||
}
|
||||
|
||||
template <uint32_t kNumHeads, uint32_t kHeadDim,
|
||||
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
|
||||
uint32_t kNumQStages, uint32_t kNumKVStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
|
||||
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
|
||||
void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_kv,
|
||||
uint32_t* cu_seq_len_k_start,
|
||||
uint32_t* cu_seq_len_k_end,
|
||||
float* logits,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
||||
// TODO: consider TMA multicast
|
||||
// For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
|
||||
// Q should be load only at once for a block
|
||||
const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q);
|
||||
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<BLOCK_Q * kNumHeads>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// Prefetch TMA descriptors
|
||||
DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
||||
if (threadIdx.x / 32 == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_q);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Shared memory configs
|
||||
// NOTES: weight may be unaligned
|
||||
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
||||
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
|
||||
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
|
||||
|
||||
// Align to swizzling alignment bytes
|
||||
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
||||
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_q = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
|
||||
SMEM_Q_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
|
||||
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
|
||||
});
|
||||
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer +
|
||||
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer +
|
||||
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages +
|
||||
SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
||||
});
|
||||
|
||||
// TMA barriers
|
||||
auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
||||
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
||||
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
|
||||
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
|
||||
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
|
||||
|
||||
// Initialize barriers
|
||||
const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32;
|
||||
if (is_tma_load_warp and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
||||
full_q_barriers[i]->init(1);
|
||||
empty_q_barriers[i]->init(kNumMathThreads);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
||||
full_kv_barriers[i]->init(1);
|
||||
empty_kv_barriers[i]->init(kNumMathThreads);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr uint32_t kNumTMARegisters = 32;
|
||||
constexpr uint32_t kNumMathRegisters = 112;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
|
||||
const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
|
||||
return {block_q_idx + gridDim.x, q_iter_idx + 1};
|
||||
};
|
||||
const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
|
||||
uint32_t start = cute::numeric_limits<uint32_t>::max();
|
||||
uint32_t end = cute::numeric_limits<uint32_t>::min();
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
|
||||
start = min(start, min(__ldg(cu_seq_len_k_start + q_idx), seq_len_kv));
|
||||
end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv));
|
||||
}
|
||||
start = start / 4 * 4;
|
||||
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
|
||||
((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
|
||||
start, ceil_div(end - start, BLOCK_KV)}; // Task info
|
||||
};
|
||||
|
||||
// KV pipeline
|
||||
uint32_t num_total_kv_blocks = 0;
|
||||
const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
|
||||
return {
|
||||
(num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
|
||||
((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
|
||||
};
|
||||
};
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// Only the first warp remains
|
||||
if (not is_tma_load_warp)
|
||||
return;
|
||||
|
||||
// Prefetch
|
||||
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
|
||||
tma_copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
|
||||
tma_copy(&tensor_map_weights, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
|
||||
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
||||
};
|
||||
if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
|
||||
issue_tma_q(0, block_q_idx);
|
||||
|
||||
// Only the first lane persistently schedules over blocks
|
||||
if (cute::elect_one_sync()) {
|
||||
while (block_q_idx < num_q_blocks) {
|
||||
CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
||||
|
||||
// Wait Q consumer release
|
||||
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
||||
|
||||
// Issue TMA Q
|
||||
if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks)
|
||||
issue_tma_q(q_stage_idx, next_block_q_idx);
|
||||
|
||||
// Issue TMA KV
|
||||
#pragma unroll
|
||||
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
|
||||
// Wait consumer release
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
|
||||
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
||||
|
||||
// Issue TMA KV
|
||||
tma_copy(&tensor_map_kv, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
|
||||
smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
|
||||
tma_copy(&tensor_map_kv_scales, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
|
||||
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
|
||||
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
||||
}
|
||||
num_total_kv_blocks += num_kv_blocks;
|
||||
|
||||
// Jump to the next block
|
||||
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto& thread_idx = threadIdx.x % kNumMathThreads;
|
||||
const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0);
|
||||
const auto& warpgroup_idx = warp_idx / 4;
|
||||
const auto& lane_idx = get_lane_idx();
|
||||
float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4];
|
||||
|
||||
const auto& warp_offset = warp_idx * 16;
|
||||
const auto& v_0_offset = lane_idx / 4 + 0;
|
||||
const auto& v_1_offset = lane_idx / 4 + 8;
|
||||
|
||||
while (block_q_idx < num_q_blocks) {
|
||||
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
||||
|
||||
// Wait TMA Q arrival
|
||||
full_q_barriers[q_stage_idx]->wait(q_phase);
|
||||
|
||||
// Read weights
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
|
||||
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
|
||||
}
|
||||
|
||||
// Compute over KV blocks
|
||||
#pragma unroll
|
||||
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
|
||||
// Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]`
|
||||
// Wait TMA KV arrival
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
|
||||
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
||||
|
||||
// Read per-KV scales
|
||||
float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
|
||||
float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);
|
||||
|
||||
// Issue WGMMA
|
||||
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
|
||||
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K,
|
||||
to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
||||
auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K,
|
||||
to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Release KV empty
|
||||
empty_kv_barriers[kv_stage_idx]->arrive();
|
||||
|
||||
// Reduce over the head dim and store
|
||||
const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset;
|
||||
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation");
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == BLOCK_Q, "Invalid accumulation");
|
||||
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
auto shifted_accum = accum + i * kNumAccumPerReduce;
|
||||
const auto& transform = [&](const uint32_t& j) {
|
||||
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
|
||||
};
|
||||
|
||||
// Intra-thread reduction
|
||||
float sum[4] = {transform(0), transform(1), transform(2), transform(3)};
|
||||
#pragma unroll
|
||||
for (uint32_t j = 1; j < kNumHeads / 8; ++ j) {
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < 4; k ++)
|
||||
sum[k] += transform(j * 4 + k);
|
||||
}
|
||||
float v_0 = (sum[0] + sum[1]) * scale_kv_0;
|
||||
float v_1 = (sum[2] + sum[3]) * scale_kv_1;
|
||||
|
||||
// Inter-thread reduction
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++ j) {
|
||||
const auto& offset = static_cast<int>(1u << j);
|
||||
v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
|
||||
v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
|
||||
}
|
||||
|
||||
// Store into the global memory
|
||||
// NOTES: we have redundant writes here, consider more carefully
|
||||
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
|
||||
logits[q_idx * stride_kv + kv_offset + v_0_offset] = v_0;
|
||||
logits[q_idx * stride_kv + kv_offset + v_1_offset] = v_1;
|
||||
}
|
||||
}
|
||||
num_total_kv_blocks += num_kv_blocks;
|
||||
|
||||
// Release Q empty
|
||||
empty_q_barriers[q_stage_idx]->arrive();
|
||||
|
||||
// Jump to the next block
|
||||
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
403
deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh
Normal file
403
deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh
Normal file
@@ -0,0 +1,403 @@
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm90_utils.cuh>
|
||||
#include <deep_gemm/impls/sm90_fp8_mqa_logits.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs>
|
||||
__global__ __launch_bounds__(32, 1)
|
||||
void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t* context_lens, uint32_t* schedule_metadata) {
|
||||
DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
|
||||
const uint32_t lane_idx = get_lane_idx();
|
||||
|
||||
uint32_t num_segs[kAlignedBatchSize / 32];
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
|
||||
const uint32_t& context_len = (k * 32 + lane_idx < batch_size ? __ldg(context_lens + k * 32 + lane_idx) : 0);
|
||||
num_segs[k] = ceil_div(context_len, SPLIT_KV);
|
||||
}
|
||||
|
||||
__shared__ uint32_t prefix_sum[kAlignedBatchSize];
|
||||
uint32_t sum = 0;
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
|
||||
uint32_t x = num_segs[k];
|
||||
#pragma unroll
|
||||
for (uint32_t offset = 1; offset < 32; offset <<= 1) {
|
||||
const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset);
|
||||
x += (lane_idx >= offset ? y : 0);
|
||||
}
|
||||
x += sum;
|
||||
prefix_sum[k * 32 + lane_idx] = x;
|
||||
sum = __shfl_sync(0xffffffff, x, 31);
|
||||
}
|
||||
|
||||
const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs;
|
||||
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
|
||||
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
|
||||
uint32_t q_idx = 0;
|
||||
while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts)
|
||||
++ q_idx;
|
||||
const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]);
|
||||
__syncwarp();
|
||||
|
||||
schedule_metadata[sm_idx * 2] = q_idx;
|
||||
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t BLOCK_KV, uint32_t kNumMathWarpGroups>
|
||||
struct PagedMQALogitsScheduler {
|
||||
uint32_t batch_size;
|
||||
const uint32_t* context_lens;
|
||||
|
||||
uint32_t current_q_idx, current_kv_idx;
|
||||
uint32_t end_q_idx, end_kv_idx;
|
||||
uint32_t current_num_kv;
|
||||
|
||||
__device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx,
|
||||
const uint32_t* context_lens, const uint32_t* schedule_meta) {
|
||||
this->batch_size = batch_size;
|
||||
this->context_lens = context_lens;
|
||||
|
||||
const auto& current_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx);
|
||||
const auto& end_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx + 1);
|
||||
current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumMathWarpGroups;
|
||||
end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumMathWarpGroups;
|
||||
|
||||
current_num_kv = current_q_idx < batch_size ? ceil_div(__ldg(this->context_lens + current_q_idx), BLOCK_KV) : 0;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) {
|
||||
q_idx = current_q_idx;
|
||||
kv_idx = current_kv_idx;
|
||||
num_kv = current_num_kv;
|
||||
|
||||
if (q_idx == end_q_idx and kv_idx == end_kv_idx)
|
||||
return false;
|
||||
|
||||
current_kv_idx += kNumMathWarpGroups;
|
||||
if (current_kv_idx >= current_num_kv) {
|
||||
++ current_q_idx;
|
||||
current_kv_idx = 0;
|
||||
current_num_kv = current_q_idx < batch_size ? ceil_div(__ldg(this->context_lens + current_q_idx), BLOCK_KV) : 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const {
|
||||
return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx;
|
||||
}
|
||||
};
|
||||
|
||||
using namespace deep_gemm::sm90;
|
||||
|
||||
template <uint32_t kNextN, uint32_t kNumHeads,
|
||||
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
||||
uint32_t kNumQStages, uint32_t kNumKVStages,
|
||||
uint32_t SPLIT_KV,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
|
||||
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
|
||||
void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
const uint64_t logits_stride, const uint64_t block_table_stride,
|
||||
const uint32_t* context_lens, float* logits,
|
||||
const uint32_t* block_table, const uint32_t* schedule_meta,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<kNextN * kNumHeads>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const auto& warpgroup_idx = warp_idx / 4;
|
||||
const auto& lane_idx = get_lane_idx();
|
||||
|
||||
// Prefetch TMA descriptors
|
||||
static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
|
||||
DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
||||
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
|
||||
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_q);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Shared memory configs
|
||||
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
||||
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
|
||||
static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
|
||||
static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
|
||||
constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
|
||||
|
||||
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
|
||||
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
|
||||
static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
|
||||
constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);
|
||||
|
||||
// Align to swizzling alignment bytes
|
||||
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
||||
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
||||
|
||||
// Q data and barriers on shared memory
|
||||
auto smem_q = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
|
||||
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
|
||||
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });
|
||||
|
||||
// Separate math warpgroups and tma load warps into KV groups
|
||||
// Each math warpgroup corresponds to a tma load warp
|
||||
const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);
|
||||
|
||||
// Per group KV data and barriers on shared memory
|
||||
const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
|
||||
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
||||
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
|
||||
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });
|
||||
|
||||
// Initialize barriers
|
||||
if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
if (kv_group_idx == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
||||
full_q_barriers[i]->init(1);
|
||||
empty_q_barriers[i]->init(kNumMathThreads);
|
||||
}
|
||||
}
|
||||
if (kv_group_idx < kNumMathWarpGroups) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
||||
full_kv_barriers[i]->init(1);
|
||||
empty_kv_barriers[i]->init(128);
|
||||
}
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr uint32_t kNumTMARegisters = 64;
|
||||
constexpr uint32_t kNumMathRegisters = 104;
|
||||
|
||||
// Scheduler
|
||||
auto scheduler = PagedMQALogitsScheduler<BLOCK_KV, kNumMathWarpGroups>(batch_size, blockIdx.x, context_lens, schedule_meta);
|
||||
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
|
||||
|
||||
// Q and KV pipeline
|
||||
const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
||||
return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
|
||||
};
|
||||
const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
||||
return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
|
||||
};
|
||||
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
||||
|
||||
if (warp_idx >= kNumMathThreads / 32) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
if (kv_group_idx >= kNumMathWarpGroups)
|
||||
return;
|
||||
|
||||
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
|
||||
if (kv_group_idx == 0 and cute::elect_one_sync()) {
|
||||
tma_copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
|
||||
tma_copy(&tensor_map_weights, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, q_idx);
|
||||
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
|
||||
uint32_t q_idx = batch_size, kv_idx, num_kv;
|
||||
uint32_t next_q_idx, next_kv_idx, next_num_kv;
|
||||
bool fetched_next_task;
|
||||
|
||||
// Prefetch the first Q
|
||||
if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)))
|
||||
issue_tma_q(0, next_q_idx), q_iter_idx = 1;
|
||||
|
||||
int kv_block_idx_ptr = 32;
|
||||
uint32_t kv_block_idx_storage;
|
||||
|
||||
while (fetched_next_task) {
|
||||
// Prefetch next Q when current Q changes
|
||||
bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1));
|
||||
q_idx = next_q_idx;
|
||||
kv_idx = next_kv_idx;
|
||||
num_kv = next_num_kv;
|
||||
|
||||
// Wait Q consumer release and issue TMA Q
|
||||
if (prefetch_q) {
|
||||
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
||||
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
||||
issue_tma_q(q_stage_idx, q_idx + 1);
|
||||
}
|
||||
|
||||
// Read KV block index
|
||||
// TODO: deal with `-1`?
|
||||
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
|
||||
kv_block_idx_ptr = 0;
|
||||
kv_block_idx_storage = (kv_idx + kv_group_idx + + lane_idx * kNumMathWarpGroups < num_kv ?
|
||||
__ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0);
|
||||
}
|
||||
const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);
|
||||
|
||||
// Wait KV consumer release
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
||||
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
||||
|
||||
// Issue TMA KV
|
||||
if (cute::elect_one_sync()) {
|
||||
tma_3d_copy(&tensor_map_kv, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
|
||||
smem_kv[kv_stage_idx], 0, 0, kv_block_idx);
|
||||
tma_copy(&tensor_map_kv_scales, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
|
||||
smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
|
||||
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Fetch next task
|
||||
fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv);
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4];
|
||||
const auto& sub_warp_offset = (warp_idx % 4) * 16;
|
||||
const auto& v_0_offset = lane_idx / 4 + 0;
|
||||
const auto& v_1_offset = lane_idx / 4 + 8;
|
||||
|
||||
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
|
||||
uint32_t q_idx = batch_size, kv_idx;
|
||||
uint32_t next_q_idx, next_kv_idx, next_num_kv;
|
||||
uint32_t q_stage_idx, q_phase;
|
||||
|
||||
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
|
||||
// Current Q changes
|
||||
if (q_idx != next_q_idx) {
|
||||
// Release Last Q empty
|
||||
if (q_iter_idx > 0)
|
||||
empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
|
||||
|
||||
// Wait TMA Q arrival
|
||||
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
||||
full_q_barriers[q_stage_idx]->wait(q_phase);
|
||||
|
||||
// Read weights
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNextN; ++ i) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
|
||||
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
|
||||
}
|
||||
}
|
||||
|
||||
// Get current Q and KV index
|
||||
q_idx = next_q_idx;
|
||||
kv_idx = next_kv_idx;
|
||||
|
||||
// Calculate KV offset in advance
|
||||
auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
|
||||
|
||||
// Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
|
||||
// Wait TMA KV arrival
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
||||
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
||||
|
||||
// Issue WGMMA
|
||||
DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size");
|
||||
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
||||
auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
|
||||
// Read per-KV scales
|
||||
float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset);
|
||||
float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset);
|
||||
|
||||
// Wait WGMMA
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Release KV empty
|
||||
empty_kv_barriers[kv_stage_idx]->arrive();
|
||||
|
||||
// Reduce over the head dim and store
|
||||
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation");
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == kNextN, "Invalid accumulation");
|
||||
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNextN; ++ i) {
|
||||
auto shifted_accum = accum + i * kNumAccumPerReduce;
|
||||
const auto& transform = [&](const uint32_t& j) {
|
||||
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
|
||||
};
|
||||
|
||||
// Intra-thread reduction
|
||||
float sum[4] = {transform(0), transform(1), transform(2), transform(3)};
|
||||
#pragma unroll
|
||||
for (uint32_t j = 1; j < kNumHeads / 8; ++ j) {
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < 4; k ++)
|
||||
sum[k] += transform(j * 4 + k);
|
||||
}
|
||||
float v_0 = (sum[0] + sum[1]) * scale_kv_0;
|
||||
float v_1 = (sum[2] + sum[3]) * scale_kv_1;
|
||||
|
||||
// Inter-thread reduction
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++ j) {
|
||||
const auto& offset = static_cast<int>(1u << j);
|
||||
v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
|
||||
v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
|
||||
}
|
||||
|
||||
// Store into the global memory
|
||||
// NOTES: we have redundant writes here, consider more carefully
|
||||
logits[kv_offset + i * logits_stride + v_0_offset] = v_0;
|
||||
logits[kv_offset + i * logits_stride + v_1_offset] = v_1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
67
deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh
Normal file
67
deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh
Normal file
@@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <uint32_t kNextN, uint32_t BLOCK_KV, uint32_t kNumWarps>
|
||||
__global__ __launch_bounds__(kNumWarps * 32, 1)
|
||||
void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_kv,
|
||||
const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) {
|
||||
const uint32_t& num_sms = gridDim.x;
|
||||
const uint32_t& sm_idx = blockIdx.x;
|
||||
const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
constexpr float neg_inf = -cute::numeric_limits<float>::infinity();
|
||||
|
||||
// Allocate filled `-inf` shared memory
|
||||
extern __shared__ __align__(1024) float smem_buffer[];
|
||||
#pragma unroll
|
||||
for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32)
|
||||
smem_buffer[i] = neg_inf;
|
||||
cute::tma_store_fence();
|
||||
__syncthreads();
|
||||
|
||||
// Assign sequence to each warp
|
||||
const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx,
|
||||
const uint32_t& start, const uint32_t& total) -> cute::tuple<uint32_t, uint32_t> {
|
||||
const auto& per = total / num, rem = total % num;
|
||||
return {start + idx * per + min(idx, rem), per + (idx < rem)};
|
||||
};
|
||||
CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len);
|
||||
CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len);
|
||||
|
||||
if (cute::elect_one_sync()) {
|
||||
for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
|
||||
const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN);
|
||||
const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1;
|
||||
const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4;
|
||||
|
||||
for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) {
|
||||
const auto& right = min(left + BLOCK_KV, static_cast<uint32_t>(stride_kv));
|
||||
if (right <= ks or ke <= left) {
|
||||
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_kv + left, (right - left) * sizeof(float));
|
||||
} else {
|
||||
if (left < aligned_ks)
|
||||
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_kv + left, (aligned_ks - left) * sizeof(float));
|
||||
if (aligned_ke < right)
|
||||
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_kv + aligned_ke, (right - aligned_ke) * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
|
||||
const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN);
|
||||
const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1;
|
||||
const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4;
|
||||
for (uint32_t j = aligned_ks; j < ks; ++ j)
|
||||
logits[i * stride_kv + j] = neg_inf;
|
||||
for (uint32_t j = ke; j < aligned_ke; ++ j)
|
||||
logits[i * stride_kv + j] = neg_inf;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user