Fix a sync issue in SM100 MQA logits (#285)

This commit is contained in:
Ray Wang
2026-02-03 17:29:49 +08:00
committed by GitHub
parent 0f5f266202
commit 477618cd51
2 changed files with 9 additions and 1 deletion

View File

@@ -252,6 +252,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1);
tcgen05_after_thread_sync();
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
@@ -310,6 +311,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
// Wait UMMA arrival
full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
tcgen05_after_thread_sync();
// Release KV empty
empty_kv_barriers[kv_stage_idx]->arrive();
@@ -334,6 +336,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
cutlass::arch::fence_view_async_tmem_load();
tcgen05_before_thread_sync();
empty_umma_barriers[warpgroup_idx]->arrive();
#pragma unroll

View File

@@ -236,8 +236,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
uint32_t umma_phase = 1;
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
if (q_idx != next_q_idx)
if (q_idx != next_q_idx) {
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
full_q_barriers[q_stage_idx]->wait(q_phase);
}
q_idx = next_q_idx;
kv_idx = next_kv_idx;
@@ -249,6 +251,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
empty_umma_barriers[i]->wait(umma_phase);
tcgen05_after_thread_sync();
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
@@ -316,6 +319,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Wait UMMA arrival
full_umma_barriers[warpgroup_idx]->wait(umma_phase);
tcgen05_after_thread_sync();
umma_phase ^= 1;
// Release KV empty
@@ -338,6 +342,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
cutlass::arch::fence_view_async_tmem_load();
tcgen05_before_thread_sync();
empty_umma_barriers[warpgroup_idx]->arrive();
#pragma unroll