P6: TMA store uses mbarrier completion (same as load)

The TMA store now issues cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes.
Completion is tracked with an mbarrier instead of a bulk_group, so the sMbarStore slot is restored in SMEM.
This commit is contained in:
2026-05-30 17:07:24 +00:00
parent 2de300e281
commit f97359fbfc
5 changed files with 28 additions and 12 deletions

View File

@@ -148,6 +148,8 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
float* s_p_vals = (float*)(sV + V_SUB_SZ);
// Epilogue SMEM: row-major O tile for TMA store
bf16_t* sO_epi = (bf16_t*)(((uintptr_t)(s_p_vals + SK_TILE) + 127) & ~(uintptr_t)127);
// TMA store mbarrier (16 bytes, 128B aligned)
uint64_t* sMbarStore = (uint64_t*)(((uintptr_t)(sO_epi + HD) + 127) & ~(uintptr_t)127);
// ================================================================
// TMEM allocation (warp 4)
@@ -329,16 +331,20 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
// so coords are relative to that head's start.
if (tid == 0) {
uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sO_epi);
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
CUtensorMap* my_tma = params.tma_o + batch_idx * gridDim.y + head_idx;
uint64_t tma_desc = (uint64_t)my_tma;
tma_store_2d(smem_addr, tma_desc, 0, 0);
tma_store_commit_group();
tma_mbarrier_init(mbar_addr, 1);
asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
tma_store_2d(smem_addr, tma_desc, mbar_addr, 0, 0);
tma_mbarrier_arrive_expect_tx(mbar_addr, HD * sizeof(bf16_t));
}
__syncthreads();
// Wait for TMA store completion
if (tid == 0) {
tma_store_wait_all();
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
tma_mbarrier_wait(mbar_addr, 0);
}
__syncthreads();
} else {

View File

@@ -108,6 +108,7 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
// P6: Row-major O epilogue buffer + TMA store mbarrier
off = (off + 127) & ~(size_t)127;
bf16_t* sO_epi_rowmajor = (bf16_t*)(sbuf + off); off += MAX_ROWS * HD_CHUNK * sizeof(bf16_t);
uint64_t* sMbarStore = (uint64_t*)(sbuf + off); off += 16;
// Init TMEM + mbarrier (once, shared across hd_chunks)
if (is_mma_warp) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
@@ -360,15 +361,18 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
if (params.tma_o) {
if (tid == 0) {
uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sO_epi_rowmajor);
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
CUtensorMap* my_tma = params.tma_o + batch_idx * params.n_h + head_idx;
uint64_t tma_desc = (uint64_t)my_tma;
// TMA coords: x = hd_chunk_start, y = 0
tma_store_2d(smem_addr, tma_desc, hd_chunk_start, 0);
tma_store_commit_group();
tma_mbarrier_init(mbar_addr, 1);
asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
tma_store_2d(smem_addr, tma_desc, mbar_addr, hd_chunk_start, 0);
tma_mbarrier_arrive_expect_tx(mbar_addr, T * HD_CHUNK * sizeof(bf16_t));
}
__syncthreads();
if (tid == 0) {
tma_store_wait_all();
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
tma_mbarrier_wait(mbar_addr, 0);
}
__syncthreads();
} else {

View File

@@ -39,8 +39,9 @@ int fmha_compute_smem(int hd) {
int sV = V_SUB_SZ * 2; // 512
int sp = SK * 4; // 512
int sO = hd * 2; // row-major O (HD BF16)
int sMbar = 16; // TMA store mbarrier
// With 128B alignment between sections
int total = base + sQ0 + sK0 + (sPk + 127) + (sV + 127) + sp + (sO + 127) + 256;
int total = base + sQ0 + sK0 + (sPk + 127) + (sV + 127) + sp + (sO + 127) + (sMbar + 127) + 256;
// Round up to 128B
return (total + 127) & ~127;
}

View File

@@ -109,6 +109,7 @@ int fmha_multitile_decode_launch(
// P6: sO_epi_rowmajor + sMbarStore
off = (off+127)&~(size_t)127;
off += 128 * hc * 2; // sO_epi_rowmajor (MAX_ROWS * HD_CHUNK BF16)
off += 16; // sMbarStore
off += 256; // slack
int smem = (int)((off + 127) & ~(size_t)127);

View File

@@ -210,17 +210,21 @@ struct FmhaTmaDescriptors {
// Issues one 2D TMA tensor copy from shared memory to global memory via inline PTX.
//
// Parameters:
//   smem_src  - 32-bit shared-memory address of the source tile (asm operand %0).
//   tma_desc  - tensor-map address passed as a 64-bit value (asm operand %1).
//   smem_mbar - 32-bit shared-memory address of the mbarrier named by the
//               mbarrier::complete_tx::bytes form (asm operand %4).
//   coord_x, coord_y - tensor coordinates of the tile (asm operands %2 and %3).
//
// The asm statement is volatile and declares a "memory" clobber, so the compiler
// will not reorder surrounding memory accesses across it.
//
// NOTE(review): this listing is a rendered diff hunk without +/- markers, so the
// superseded .bulk_group lines appear next to the .mbarrier::complete_tx::bytes
// lines that replace them; presumably only one variant exists in the real file.
// NOTE(review): the PTX ISA appears to document only .bulk_group completion (with a
// .shared::cta source) for the shared -> global direction of cp.async.bulk.tensor.
// Confirm that the mbarrier form assembles and actually signals the barrier on the
// target architecture; otherwise a caller waiting on smem_mbar would never wake.
__device__ __forceinline__ void tma_store_2d(
uint32_t smem_src,
uint64_t tma_desc,
uint32_t smem_mbar,
int coord_x,
int coord_y
) {
// cp.async.bulk.tensor.2d.global.shared::cluster.bulk_group [tensorMap, {coord}], [srcMem];
// cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes
// For store: dest=global, src=shared::cluster
// mbarrier on SMEM side tracks completion
asm volatile(
// Pre-change instruction text (bulk_group completion, no mbarrier operand).
"cp.async.bulk.tensor.2d.global.shared::cluster.bulk_group "
"[%1, {%2, %3}], [%0];"
// Post-change instruction text: adds [%4], the mbarrier address.
"cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes "
"[%1, {%2, %3}], [%0], [%4];"
:: "r"(smem_src),
"l"(tma_desc),
"r"(coord_x),
// Pre-change final operand line (no trailing comma).
"r"(coord_y)
// Post-change operand lines: coord_y followed by the new smem_mbar operand.
"r"(coord_y),
"r"(smem_mbar)
: "memory"
);
}