diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh index c51f360..180a308 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -252,6 +252,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1); + tcgen05_after_thread_sync(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { auto a_desc = make_umma_desc( @@ -310,6 +311,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // Wait UMMA arrival full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1); + tcgen05_after_thread_sync(); // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); @@ -334,6 +336,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); cutlass::arch::fence_view_async_tmem_load(); + tcgen05_before_thread_sync(); empty_umma_barriers[warpgroup_idx]->arrive(); #pragma unroll diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh index 049ba74..7058c40 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -236,8 +236,10 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, uint32_t umma_phase = 1; while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { - if (q_idx != next_q_idx) + if (q_idx != next_q_idx) { CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + } q_idx = next_q_idx; kv_idx = next_kv_idx; @@ -249,6 +251,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { empty_umma_barriers[i]->wait(umma_phase); + tcgen05_after_thread_sync(); #pragma unroll for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { auto a_desc = make_umma_desc( @@ -316,6 +319,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Wait UMMA arrival full_umma_barriers[warpgroup_idx]->wait(umma_phase); + tcgen05_after_thread_sync(); umma_phase ^= 1; // Release KV empty @@ -338,6 +342,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); cutlass::arch::fence_view_async_tmem_load(); + tcgen05_before_thread_sync(); empty_umma_barriers[warpgroup_idx]->arrive(); #pragma unroll