diff --git a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh index 0227b3e..d9645b3 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -132,6 +132,9 @@ sm100_bf16_gemm_impl(int* grouped_layout, auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1); DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + if (kNumMulticast > 1) + cute::cluster_sync(); + // Initialize barriers if (warp_idx == 1 and cute::elect_one_sync()) { #pragma unroll @@ -465,12 +468,13 @@ sm100_bf16_gemm_impl(int* grouped_layout, } } } - - // Deallocate tensor memory by the last UMMA store warp - // NOTES: warp 0 is waiting TMA store - if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) - Allocator().free(0, kNumTmemCols); } + + // Deallocate tensor memory + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); diff --git a/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh b/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh index 8630334..f5101b9 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh @@ -251,6 +251,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, } } + __syncthreads(); // Deallocate tensor memory by warp 1 // NOTES: warp 0 is doing TMA stores if (warp_idx == 1) diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 45a603a..7ce008e 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -155,6 +155,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + if (kNumMulticast > 1) + cute::cluster_sync(); + // Initialize barriers if (warp_idx == 1 and cute::elect_one_sync()) { #pragma unroll @@ -546,12 +549,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, } } } - - // Deallocate tensor memory by the last UMMA store warp - // NOTES: warp 0 is waiting TMA store - if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) - Allocator().free(0, kNumTmemCols); } + + // Deallocate tensor memory + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + #else if (blockIdx.x == 0 and threadIdx.x == 0) DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");