P6: Fix TMA store — use bulk_group (commit+wait) not mbarrier

TMA store uses cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group
NOT mbarrier::complete_tx::bytes. Completion tracked via:
  - cp.async.bulk.commit_group (after issuing stores)
  - cp.async.bulk.wait_group.read 0 (wait for all groups)

Removed sMbarStore from SMEM allocations (no longer needed).
This commit is contained in:
2026-05-30 16:57:35 +00:00
parent 212fc85627
commit fd7c0cb773
5 changed files with 40 additions and 57 deletions

View File

@@ -148,8 +148,6 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
float* s_p_vals = (float*)(sV + V_SUB_SZ);
// Epilogue SMEM: row-major O tile for TMA store
bf16_t* sO_epi = (bf16_t*)(((uintptr_t)(s_p_vals + SK_TILE) + 127) & ~(uintptr_t)127);
// TMA store mbarrier (16 bytes, 128B aligned)
uint64_t* sMbarStore = (uint64_t*)(((uintptr_t)(sO_epi + HD) + 127) & ~(uintptr_t)127);
// ================================================================
// TMEM allocation (warp 4)
@@ -320,37 +318,27 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
// Step 4: TMA store SMEM → GMEM (or direct GMEM write)
if (params.tma_o) {
// Proper TMA store path — async, enables multi-CTA grids.
// One thread initializes the mbarrier and issues the TMA store.
if (tid == 0) {
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
tma_mbarrier_init(mbar_addr, 1);
asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
}
__syncthreads();
// Uses bulk_group completion (commit_group + wait_group), NOT mbarrier.
//
// The O tensor is [batch, n_h, T, HD]. Each head's output starts
// at a different GMEM offset. The TMA descriptor covers the full
// O tensor. We index by (head, batch) to find the right descriptor.
// at a different GMEM offset. The TMA descriptor covers one head's
// output. We index by (head, batch) to find the right descriptor.
//
// TMA coords for this head's row 0: (x=0, y=0)
// The descriptor was created with the head's GMEM base pointer,
// so coords are relative to that head's start.
if (tid == 0) {
uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sO_epi);
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
// Get this (head, batch) TMA descriptor
CUtensorMap* my_tma = params.tma_o + batch_idx * gridDim.y + head_idx;
uint64_t tma_desc = (uint64_t)my_tma;
tma_store_2d(smem_addr, tma_desc, mbar_addr, 0, 0);
// TMA tile: (1, HD) BF16 = HD*2 bytes
tma_mbarrier_arrive_expect_tx(mbar_addr, HD * sizeof(bf16_t));
tma_store_2d(smem_addr, tma_desc, 0, 0);
tma_store_commit_group();
}
__syncthreads();
// Wait for TMA store completion
if (tid == 0) {
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
tma_store_wait(mbar_addr, 0);
tma_store_wait_all();
}
__syncthreads();
} else {

View File

@@ -108,7 +108,6 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
// P6: Row-major O epilogue buffer + TMA store mbarrier
off = (off + 127) & ~(size_t)127;
bf16_t* sO_epi_rowmajor = (bf16_t*)(sbuf + off); off += MAX_ROWS * HD_CHUNK * sizeof(bf16_t);
uint64_t* sMbarStore = (uint64_t*)(sbuf + off); off += 16;
// Init TMEM + mbarrier (once, shared across hd_chunks)
if (is_mma_warp) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
@@ -359,26 +358,17 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
// TMA store sO_epi_rowmajor → GMEM
if (params.tma_o) {
if (tid == 0) {
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
tma_mbarrier_init(mbar_addr, 1);
asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
}
__syncthreads();
if (tid == 0) {
uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sO_epi_rowmajor);
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
CUtensorMap* my_tma = params.tma_o + batch_idx * params.n_h + head_idx;
uint64_t tma_desc = (uint64_t)my_tma;
// TMA coords: x = hd_chunk_start, y = 0
tma_store_2d(smem_addr, tma_desc, mbar_addr, hd_chunk_start, 0);
tma_mbarrier_arrive_expect_tx(mbar_addr, T * HD_CHUNK * sizeof(bf16_t));
tma_store_2d(smem_addr, tma_desc, hd_chunk_start, 0);
tma_store_commit_group();
}
__syncthreads();
if (tid == 0) {
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
tma_store_wait(mbar_addr, 0);
tma_store_wait_all();
}
__syncthreads();
} else {

View File

@@ -39,9 +39,8 @@ int fmha_compute_smem(int hd) {
int sV = V_SUB_SZ * 2; // 512
int sp = SK * 4; // 512
int sO = hd * 2; // row-major O (HD BF16)
int sMbar = 16; // TMA store mbarrier
// With 128B alignment between sections
int total = base + sQ0 + sK0 + (sPk + 127) + (sV + 127) + sp + (sO + 127) + (sMbar + 127) + 256;
int total = base + sQ0 + sK0 + (sPk + 127) + (sV + 127) + sp + (sO + 127) + 256;
// Round up to 128B
return (total + 127) & ~127;
}

View File

@@ -109,7 +109,6 @@ int fmha_multitile_decode_launch(
// P6: sO_epi_rowmajor + sMbarStore
off = (off+127)&~(size_t)127;
off += 128 * hc * 2; // sO_epi_rowmajor (MAX_ROWS * HD_CHUNK BF16)
off += 16; // sMbarStore
off += 256; // slack
int smem = (int)((off + 127) & ~(size_t)127);

View File

@@ -196,25 +196,28 @@ struct FmhaTmaDescriptors {
/**
* Issue a 2D TMA async copy from SMEM to GMEM (store).
*
* PTX syntax: cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group
* [tensorMap, tensorCoords], [srcMem];
*
* IMPORTANT: TMA store uses bulk_group (not mbarrier) for completion tracking.
* Completion is tracked via cp.async.bulk.commit_group + cp.async.bulk.wait_group.
*
* @param smem_src SMEM source address (via __cvta_generic_to_shared)
* @param tma_desc Pointer to CUtensorMap in device memory (uint64_t cast)
* @param smem_mbar SMEM mbarrier address (via __cvta_generic_to_shared)
* @param coord_x Column coordinate (innermost dimension)
* @param coord_y Row coordinate (outermost dimension)
*/
__device__ __forceinline__ void tma_store_2d(
uint32_t smem_src,
uint64_t tma_desc,
uint32_t smem_mbar,
int coord_x,
int coord_y
) {
asm volatile(
"cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes "
"[%0, {%3, %4}], [%1], [%2];"
:: "l"(tma_desc),
"r"(smem_src),
"r"(smem_mbar),
"cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group "
"[%1, {%2, %3}], [%0];"
:: "r"(smem_src),
"l"(tma_desc),
"r"(coord_x),
"r"(coord_y)
: "memory"
@@ -222,22 +225,26 @@ __device__ __forceinline__ void tma_store_2d(
}
/**
* Wait for TMA store completion using mbarrier try_wait.
* Same pattern as tma_mbarrier_wait but for store mbarriers.
* Commit all pending TMA store operations to a group.
* Must be called after issuing TMA stores, before waiting.
*/
__device__ __forceinline__ void tma_store_wait(uint32_t smem_mbar, int phase) {
asm volatile(
"{\n\t"
".reg .pred P1;\n\t"
"LAB_WAIT:"
"mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\n\t"
"@P1 bra.uni DONE;\n\t"
"bra.uni LAB_WAIT;\n\t"
"DONE:\n\t"
"}"
:: "r"(smem_mbar), "r"(phase), "r"(0x989680)
: "memory"
);
__device__ __forceinline__ void tma_store_commit_group() {
asm volatile("cp.async.bulk.commit_group;" ::: "memory");
}
/**
* Wait for N most recent TMA store groups to complete.
* N=0 waits for all groups. N=1 waits for the most recent group.
*/
__device__ __forceinline__ void tma_store_wait_group(int n) {
asm volatile("cp.async.bulk.wait_group.read %0;" :: "r"(n) : "memory");
}
/**
* Wait for all pending TMA store groups to complete.
*/
__device__ __forceinline__ void tma_store_wait_all() {
asm volatile("cp.async.bulk.wait_group.read 0;" ::: "memory");
}
} // namespace dsv4::kernels::attention