Fix sync issue of TMEM alloc/dealloc (#292)

This commit is contained in:
Ray Wang
2026-03-22 16:41:28 +08:00
committed by GitHub
parent 35c4bc8771
commit d30fc36c8f
3 changed files with 19 additions and 10 deletions

View File

@@ -132,6 +132,9 @@ sm100_bf16_gemm_impl(int* grouped_layout,
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
if (kNumMulticast > 1)
cute::cluster_sync();
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
@@ -465,12 +468,13 @@ sm100_bf16_gemm_impl(int* grouped_layout,
}
}
}
// Deallocate tensor memory by the last UMMA store warp
// NOTES: warp 0 is waiting TMA store
if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
Allocator().free(0, kNumTmemCols);
}
// Deallocate tensor memory
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
if (warp_idx == 0)
Allocator().free(0, kNumTmemCols);
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");

View File

@@ -251,6 +251,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
}
}
__syncthreads();
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is doing TMA stores
if (warp_idx == 1)

View File

@@ -155,6 +155,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
if (kNumMulticast > 1)
cute::cluster_sync();
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
@@ -546,12 +549,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
}
}
}
// Deallocate tensor memory by the last UMMA store warp
// NOTES: warp 0 is waiting TMA store
if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
Allocator().free(0, kNumTmemCols);
}
// Deallocate tensor memory
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
if (warp_idx == 0)
Allocator().free(0, kNumTmemCols);
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");