P6: TMA store uses mbarrier completion (same as load)
TMA store: cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes Uses mbarrier for completion, not bulk_group. Restored sMbarStore to SMEM.
This commit is contained in:
@@ -148,6 +148,8 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
|
||||
float* s_p_vals = (float*)(sV + V_SUB_SZ);
|
||||
// Epilogue SMEM: row-major O tile for TMA store
|
||||
bf16_t* sO_epi = (bf16_t*)(((uintptr_t)(s_p_vals + SK_TILE) + 127) & ~(uintptr_t)127);
|
||||
// TMA store mbarrier (16 bytes, 128B aligned)
|
||||
uint64_t* sMbarStore = (uint64_t*)(((uintptr_t)(sO_epi + HD) + 127) & ~(uintptr_t)127);
|
||||
|
||||
// ================================================================
|
||||
// TMEM allocation (warp 4)
|
||||
@@ -329,16 +331,20 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
|
||||
// so coords are relative to that head's start.
|
||||
if (tid == 0) {
|
||||
uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sO_epi);
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
|
||||
CUtensorMap* my_tma = params.tma_o + batch_idx * gridDim.y + head_idx;
|
||||
uint64_t tma_desc = (uint64_t)my_tma;
|
||||
tma_store_2d(smem_addr, tma_desc, 0, 0);
|
||||
tma_store_commit_group();
|
||||
tma_mbarrier_init(mbar_addr, 1);
|
||||
asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
|
||||
tma_store_2d(smem_addr, tma_desc, mbar_addr, 0, 0);
|
||||
tma_mbarrier_arrive_expect_tx(mbar_addr, HD * sizeof(bf16_t));
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Wait for TMA store completion
|
||||
if (tid == 0) {
|
||||
tma_store_wait_all();
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
|
||||
tma_mbarrier_wait(mbar_addr, 0);
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
|
||||
@@ -108,6 +108,7 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
|
||||
// P6: Row-major O epilogue buffer + TMA store mbarrier
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sO_epi_rowmajor = (bf16_t*)(sbuf + off); off += MAX_ROWS * HD_CHUNK * sizeof(bf16_t);
|
||||
uint64_t* sMbarStore = (uint64_t*)(sbuf + off); off += 16;
|
||||
|
||||
// Init TMEM + mbarrier (once, shared across hd_chunks)
|
||||
if (is_mma_warp) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
|
||||
@@ -360,15 +361,18 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
|
||||
if (params.tma_o) {
|
||||
if (tid == 0) {
|
||||
uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sO_epi_rowmajor);
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
|
||||
CUtensorMap* my_tma = params.tma_o + batch_idx * params.n_h + head_idx;
|
||||
uint64_t tma_desc = (uint64_t)my_tma;
|
||||
// TMA coords: x = hd_chunk_start, y = 0
|
||||
tma_store_2d(smem_addr, tma_desc, hd_chunk_start, 0);
|
||||
tma_store_commit_group();
|
||||
tma_mbarrier_init(mbar_addr, 1);
|
||||
asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
|
||||
tma_store_2d(smem_addr, tma_desc, mbar_addr, hd_chunk_start, 0);
|
||||
tma_mbarrier_arrive_expect_tx(mbar_addr, T * HD_CHUNK * sizeof(bf16_t));
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid == 0) {
|
||||
tma_store_wait_all();
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
|
||||
tma_mbarrier_wait(mbar_addr, 0);
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
|
||||
@@ -39,8 +39,9 @@ int fmha_compute_smem(int hd) {
|
||||
int sV = V_SUB_SZ * 2; // 512
|
||||
int sp = SK * 4; // 512
|
||||
int sO = hd * 2; // row-major O (HD BF16)
|
||||
int sMbar = 16; // TMA store mbarrier
|
||||
// With 128B alignment between sections
|
||||
int total = base + sQ0 + sK0 + (sPk + 127) + (sV + 127) + sp + (sO + 127) + 256;
|
||||
int total = base + sQ0 + sK0 + (sPk + 127) + (sV + 127) + sp + (sO + 127) + (sMbar + 127) + 256;
|
||||
// Round up to 128B
|
||||
return (total + 127) & ~127;
|
||||
}
|
||||
|
||||
@@ -109,6 +109,7 @@ int fmha_multitile_decode_launch(
|
||||
// P6: sO_epi_rowmajor + sMbarStore
|
||||
off = (off+127)&~(size_t)127;
|
||||
off += 128 * hc * 2; // sO_epi_rowmajor (MAX_ROWS * HD_CHUNK BF16)
|
||||
off += 16; // sMbarStore
|
||||
off += 256; // slack
|
||||
int smem = (int)((off + 127) & ~(size_t)127);
|
||||
|
||||
|
||||
@@ -210,17 +210,21 @@ struct FmhaTmaDescriptors {
|
||||
__device__ __forceinline__ void tma_store_2d(
|
||||
uint32_t smem_src,
|
||||
uint64_t tma_desc,
|
||||
uint32_t smem_mbar,
|
||||
int coord_x,
|
||||
int coord_y
|
||||
) {
|
||||
// cp.async.bulk.tensor.2d.global.shared::cluster.bulk_group [tensorMap, {coord}], [srcMem];
|
||||
// cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes
|
||||
// For store: dest=global, src=shared::cluster
|
||||
// mbarrier on SMEM side tracks completion
|
||||
asm volatile(
|
||||
"cp.async.bulk.tensor.2d.global.shared::cluster.bulk_group "
|
||||
"[%1, {%2, %3}], [%0];"
|
||||
"cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes "
|
||||
"[%1, {%2, %3}], [%0], [%4];"
|
||||
:: "r"(smem_src),
|
||||
"l"(tma_desc),
|
||||
"r"(coord_x),
|
||||
"r"(coord_y)
|
||||
"r"(coord_y),
|
||||
"r"(smem_mbar)
|
||||
: "memory"
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user