Fix a sync issue in SM100 MQA logits (#285)
This commit is contained in:
@@ -252,6 +252,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
||||
empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1);
|
||||
tcgen05_after_thread_sync();
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
|
||||
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
||||
@@ -310,6 +311,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
|
||||
// Wait UMMA arrival
|
||||
full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Release KV empty
|
||||
empty_kv_barriers[kv_stage_idx]->arrive();
|
||||
@@ -334,6 +336,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
|
||||
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
|
||||
tcgen05_before_thread_sync();
|
||||
empty_umma_barriers[warpgroup_idx]->arrive();
|
||||
|
||||
#pragma unroll
|
||||
|
||||
@@ -236,8 +236,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
uint32_t umma_phase = 1;
|
||||
|
||||
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
|
||||
- if (q_idx != next_q_idx)
|
||||
+ if (q_idx != next_q_idx) {
|
||||
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
||||
full_q_barriers[q_stage_idx]->wait(q_phase);
|
||||
}
|
||||
|
||||
q_idx = next_q_idx;
|
||||
kv_idx = next_kv_idx;
|
||||
@@ -249,6 +251,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
|
||||
empty_umma_barriers[i]->wait(umma_phase);
|
||||
tcgen05_after_thread_sync();
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
|
||||
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
|
||||
@@ -316,6 +319,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
|
||||
// Wait UMMA arrival
|
||||
full_umma_barriers[warpgroup_idx]->wait(umma_phase);
|
||||
tcgen05_after_thread_sync();
|
||||
umma_phase ^= 1;
|
||||
|
||||
// Release KV empty
|
||||
@@ -338,6 +342,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
|
||||
tcgen05_before_thread_sync();
|
||||
empty_umma_barriers[warpgroup_idx]->arrive();
|
||||
|
||||
#pragma unroll
|
||||
|
||||
Reference in New Issue
Block a user