Use larger MMA shape (#227)
This commit is contained in:
@@ -87,7 +87,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
|
||||
constexpr int block_qh = 128;
|
||||
constexpr int block_kv = 256;
|
||||
constexpr int num_specialized_threads = 128;
|
||||
constexpr int num_math_threads = 512;
|
||||
constexpr int num_math_threads = 256;
|
||||
constexpr int num_q_stages = 3, num_kv_stages = 3;
|
||||
const int block_q = block_qh / num_heads;
|
||||
DG_HOST_ASSERT(block_qh % num_heads == 0);
|
||||
|
||||
@@ -134,7 +134,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr uint32_t kNumSpecializedRegisters = 32;
|
||||
constexpr uint32_t kNumMathRegisters = 112;
|
||||
constexpr uint32_t kNumMathRegisters = 232;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
|
||||
@@ -168,7 +168,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
|
||||
|
||||
// UMMA settings
|
||||
// Construct instruction with layout F
|
||||
constexpr uint32_t UMMA_M = 64;
|
||||
constexpr uint32_t UMMA_M = 128;
|
||||
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
|
||||
constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
|
||||
|
||||
@@ -242,7 +242,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
|
||||
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
||||
|
||||
// Issue UMMA
|
||||
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
|
||||
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads, "Invalid block size");
|
||||
DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim");
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
||||
@@ -270,10 +270,11 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
|
||||
|
||||
// Offsets
|
||||
const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
|
||||
float weights[BLOCK_Q][kNumHeads / 4];
|
||||
const auto& warp_offset = warp_idx * 16;
|
||||
const auto& v_0_offset = lane_idx / 4 + 0;
|
||||
const auto& v_1_offset = lane_idx / 4 + 8;
|
||||
// Number of registers is not sufficient for all weights, so we need to cache some of them
|
||||
constexpr uint32_t kNumWeightsInReg = 48;
|
||||
float weights[BLOCK_Q][kNumWeightsInReg];
|
||||
const auto& warp_offset = warp_idx * 32;
|
||||
const auto& v_offset = lane_idx;
|
||||
|
||||
while (block_q_idx < num_q_blocks) {
|
||||
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
||||
@@ -284,9 +285,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
|
||||
// Read weights
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
|
||||
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
|
||||
for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) {
|
||||
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute over KV blocks
|
||||
@@ -298,8 +299,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
|
||||
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
||||
|
||||
// Read per-KV scales
|
||||
auto scale_kv = make_float2(ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset),
|
||||
ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset));
|
||||
float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset);
|
||||
|
||||
// Wait UMMA arrival
|
||||
full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
|
||||
@@ -311,59 +311,49 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
|
||||
const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset;
|
||||
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
|
||||
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
||||
#pragma unroll
|
||||
constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q;
|
||||
DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems");
|
||||
uint32_t shifted_accum[kNumLDTMElems];
|
||||
auto tmem_load = [&](auto... Is) {
|
||||
if constexpr (kNumLDTMElems == 32) {
|
||||
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
|
||||
} else if constexpr (kNumLDTMElems == 64) {
|
||||
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
|
||||
} else if constexpr (kNumLDTMElems == 128) {
|
||||
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
|
||||
}
|
||||
};
|
||||
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
|
||||
empty_umma_barriers[warpgroup_idx]->arrive();
|
||||
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
// Load from the tensor memory
|
||||
constexpr uint32_t kNumLDTMElems = UMMA_M * kNumHeads / 128;
|
||||
uint32_t shifted_accum[kNumLDTMElems];
|
||||
DG_STATIC_ASSERT(kNumLDTMElems == 16 or kNumLDTMElems == 32 or kNumLDTMElems == 64, "Invalid LDTM");
|
||||
auto tmem_load = [&](auto... Is) {
|
||||
if constexpr (kNumLDTMElems == 16) {
|
||||
cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
|
||||
} else if constexpr (kNumLDTMElems == 32) {
|
||||
cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
|
||||
} else if constexpr (kNumLDTMElems == 64) {
|
||||
cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
|
||||
}
|
||||
};
|
||||
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
float* accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
|
||||
|
||||
// Release UMMA empty
|
||||
if (i == BLOCK_Q - 1)
|
||||
empty_umma_barriers[warpgroup_idx]->arrive();
|
||||
auto sum = make_float2(0, 0);
|
||||
|
||||
// Transform
|
||||
const auto& transform_2 = [&](const uint32_t& j, const uint32_t& k, const float2& sum) {
|
||||
auto a = make_float2(fmaxf(*reinterpret_cast<float*>(&shifted_accum[j * 4 + k]), 0),
|
||||
fmaxf(*reinterpret_cast<float*>(&shifted_accum[j * 4 + k + 2]), 0));
|
||||
auto b = make_float2(weights[i][j * 2 + k], weights[i][j * 2 + k]);
|
||||
return __ffma2_rn(a, b, sum);
|
||||
};
|
||||
|
||||
// Intra-thread reduction
|
||||
auto sum_0 = make_float2(0, 0);
|
||||
auto sum_1 = make_float2(0, 0);
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads / 8; ++ j) {
|
||||
sum_0 = transform_2(j, 0, sum_0);
|
||||
sum_1 = transform_2(j, 1, sum_1);
|
||||
for (int j = 0; j < kNumWeightsInReg; j += 2) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0),
|
||||
fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
||||
sum = __ffma2_rn(a, b, sum);
|
||||
}
|
||||
auto v = __fmul2_rn(__fadd2_rn(sum_0, sum_1), scale_kv);
|
||||
|
||||
// Inter-thread reduction
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++ j) {
|
||||
const auto& offset = 1u << j;
|
||||
v.x += __shfl_xor_sync(0xffffffffu, v.x, offset);
|
||||
v.y += __shfl_xor_sync(0xffffffffu, v.y, offset);
|
||||
for (int j = kNumWeightsInReg; j < kNumHeads; j += 2) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
|
||||
ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
|
||||
sum = __ffma2_rn(a, b, sum);
|
||||
}
|
||||
|
||||
float result = sum.x + sum.y;
|
||||
result *= scale_kv;
|
||||
|
||||
// Store into the global memory
|
||||
// NOTES: we have redundant writes here, consider more carefully
|
||||
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
|
||||
logits[q_idx * stride_kv + kv_offset + v_0_offset] = v.x;
|
||||
logits[q_idx * stride_kv + kv_offset + v_1_offset] = v.y;
|
||||
logits[q_idx * stride_kv + kv_offset + v_offset] = result;
|
||||
}
|
||||
}
|
||||
num_total_kv_blocks += num_kv_blocks;
|
||||
|
||||
Reference in New Issue
Block a user