Fix sync issue of TMEM alloc/dealloc (#292)
This commit is contained in:
@@ -132,6 +132,9 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1);
|
||||
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
||||
|
||||
if (kNumMulticast > 1)
|
||||
cute::cluster_sync();
|
||||
|
||||
// Initialize barriers
|
||||
if (warp_idx == 1 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
@@ -465,12 +468,13 @@ sm100_bf16_gemm_impl(int* grouped_layout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deallocate tensor memory by the last UMMA store warp
|
||||
// NOTES: warp 0 is waiting TMA store
|
||||
if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
|
||||
Allocator().free(0, kNumTmemCols);
|
||||
}
|
||||
|
||||
// Deallocate tensor memory
|
||||
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
||||
if (warp_idx == 0)
|
||||
Allocator().free(0, kNumTmemCols);
|
||||
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
|
||||
|
||||
@@ -251,6 +251,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
// Deallocate tensor memory by warp 1
|
||||
// NOTES: warp 0 is doing TMA stores
|
||||
if (warp_idx == 1)
|
||||
|
||||
@@ -155,6 +155,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
|
||||
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
||||
|
||||
if (kNumMulticast > 1)
|
||||
cute::cluster_sync();
|
||||
|
||||
// Initialize barriers
|
||||
if (warp_idx == 1 and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
@@ -546,12 +549,13 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deallocate tensor memory by the last UMMA store warp
|
||||
// NOTES: warp 0 is waiting TMA store
|
||||
if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1)
|
||||
Allocator().free(0, kNumTmemCols);
|
||||
}
|
||||
|
||||
// Deallocate tensor memory
|
||||
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
||||
if (warp_idx == 0)
|
||||
Allocator().free(0, kNumTmemCols);
|
||||
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
|
||||
|
||||
Reference in New Issue
Block a user