diff --git a/dsv4/kernels/attention/fmha_6warp_multihead.cuh b/dsv4/kernels/attention/fmha_6warp_multihead.cuh index 11543235..713cec12 100644 --- a/dsv4/kernels/attention/fmha_6warp_multihead.cuh +++ b/dsv4/kernels/attention/fmha_6warp_multihead.cuh @@ -148,8 +148,6 @@ fmha_6warp_multihead_kernel(FmhaParams params) { float* s_p_vals = (float*)(sV + V_SUB_SZ); // Epilogue SMEM: row-major O tile for TMA store bf16_t* sO_epi = (bf16_t*)(((uintptr_t)(s_p_vals + SK_TILE) + 127) & ~(uintptr_t)127); - // TMA store mbarrier (16 bytes, 128B aligned) - uint64_t* sMbarStore = (uint64_t*)(((uintptr_t)(sO_epi + HD) + 127) & ~(uintptr_t)127); // ================================================================ // TMEM allocation (warp 4) @@ -320,37 +318,27 @@ fmha_6warp_multihead_kernel(FmhaParams params) { // Step 4: TMA store SMEM → GMEM (or direct GMEM write) if (params.tma_o) { // Proper TMA store path — async, enables multi-CTA grids. - // One thread initializes the mbarrier and issues the TMA store. - if (tid == 0) { - uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore); - tma_mbarrier_init(mbar_addr, 1); - asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory"); - } - __syncthreads(); - + // Uses bulk_group completion (commit_group + wait_group), NOT mbarrier. + // // The O tensor is [batch, n_h, T, HD]. Each head's output starts - // at a different GMEM offset. The TMA descriptor covers the full - // O tensor. We index by (head, batch) to find the right descriptor. + // at a different GMEM offset. The TMA descriptor covers one head's + // output. We index by (head, batch) to find the right descriptor. // // TMA coords for this head's row 0: (x=0, y=0) // The descriptor was created with the head's GMEM base pointer, // so coords are relative to that head's start. if (tid == 0) { uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sO_epi); - uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore); - // Get this (head, batch) TMA descriptor CUtensorMap* my_tma = params.tma_o + batch_idx * gridDim.y + head_idx; uint64_t tma_desc = (uint64_t)my_tma; - tma_store_2d(smem_addr, tma_desc, mbar_addr, 0, 0); - // TMA tile: (1, HD) BF16 = HD*2 bytes - tma_mbarrier_arrive_expect_tx(mbar_addr, HD * sizeof(bf16_t)); + tma_store_2d(smem_addr, tma_desc, 0, 0); + tma_store_commit_group(); } __syncthreads(); // Wait for TMA store completion if (tid == 0) { - uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore); - tma_store_wait(mbar_addr, 0); + tma_store_wait_all(); } __syncthreads(); } else { diff --git a/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh b/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh index b4afe40f..6516c82f 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh @@ -108,7 +108,6 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params) // P6: Row-major O epilogue buffer + TMA store mbarrier off = (off + 127) & ~(size_t)127; bf16_t* sO_epi_rowmajor = (bf16_t*)(sbuf + off); off += MAX_ROWS * HD_CHUNK * sizeof(bf16_t); - uint64_t* sMbarStore = (uint64_t*)(sbuf + off); off += 16; // Init TMEM + mbarrier (once, shared across hd_chunks) if (is_mma_warp) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N); @@ -359,26 +358,17 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params) // TMA store sO_epi_rowmajor → GMEM if (params.tma_o) { - if (tid == 0) { - uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore); - tma_mbarrier_init(mbar_addr, 1); - asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory"); - } - __syncthreads(); - if (tid == 0) { uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sO_epi_rowmajor); - uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore); CUtensorMap* my_tma = params.tma_o + batch_idx * params.n_h + head_idx; uint64_t tma_desc = (uint64_t)my_tma; // TMA coords: x = hd_chunk_start, y = 0 - tma_store_2d(smem_addr, tma_desc, mbar_addr, hd_chunk_start, 0); - tma_mbarrier_arrive_expect_tx(mbar_addr, T * HD_CHUNK * sizeof(bf16_t)); + tma_store_2d(smem_addr, tma_desc, hd_chunk_start, 0); + tma_store_commit_group(); } __syncthreads(); if (tid == 0) { - uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore); - tma_store_wait(mbar_addr, 0); + tma_store_wait_all(); } __syncthreads(); } else { diff --git a/dsv4/kernels/attention/fmha_multihead_capi.cu b/dsv4/kernels/attention/fmha_multihead_capi.cu index 47bcb442..05569730 100644 --- a/dsv4/kernels/attention/fmha_multihead_capi.cu +++ b/dsv4/kernels/attention/fmha_multihead_capi.cu @@ -39,9 +39,8 @@ int fmha_compute_smem(int hd) { int sV = V_SUB_SZ * 2; // 512 int sp = SK * 4; // 512 int sO = hd * 2; // row-major O (HD BF16) - int sMbar = 16; // TMA store mbarrier // With 128B alignment between sections - int total = base + sQ0 + sK0 + (sPk + 127) + (sV + 127) + sp + (sO + 127) + (sMbar + 127) + 256; + int total = base + sQ0 + sK0 + (sPk + 127) + (sV + 127) + sp + (sO + 127) + 256; // Round up to 128B return (total + 127) & ~127; } diff --git a/dsv4/kernels/attention/fmha_multitile_capi.cu b/dsv4/kernels/attention/fmha_multitile_capi.cu index 676f7655..e306ffa1 100644 --- a/dsv4/kernels/attention/fmha_multitile_capi.cu +++ b/dsv4/kernels/attention/fmha_multitile_capi.cu @@ -109,7 +109,6 @@ int fmha_multitile_decode_launch( // P6: sO_epi_rowmajor + sMbarStore off = (off+127)&~(size_t)127; off += 128 * hc * 2; // sO_epi_rowmajor (MAX_ROWS * HD_CHUNK BF16) - off += 16; // sMbarStore off += 256; // slack int smem = (int)((off + 127) & ~(size_t)127); diff --git a/dsv4/kernels/attention/fmha_tma.cuh b/dsv4/kernels/attention/fmha_tma.cuh index d83aa89d..362ca91a 100644 --- a/dsv4/kernels/attention/fmha_tma.cuh +++ b/dsv4/kernels/attention/fmha_tma.cuh @@ -196,25 +196,28 @@ struct FmhaTmaDescriptors { /** * Issue a 2D TMA async copy from SMEM to GMEM (store). * + * PTX syntax: cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group + * [tensorMap, tensorCoords], [srcMem]; + * + * IMPORTANT: TMA store uses bulk_group (not mbarrier) for completion tracking. + * Completion is tracked via cp.async.bulk.commit_group + cp.async.bulk.wait_group. + * * @param smem_src SMEM source address (via __cvta_generic_to_shared) * @param tma_desc Pointer to CUtensorMap in device memory (uint64_t cast) - * @param smem_mbar SMEM mbarrier address (via __cvta_generic_to_shared) * @param coord_x Column coordinate (innermost dimension) * @param coord_y Row coordinate (outermost dimension) */ __device__ __forceinline__ void tma_store_2d( uint32_t smem_src, uint64_t tma_desc, - uint32_t smem_mbar, int coord_x, int coord_y ) { asm volatile( - "cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes " - "[%0, {%3, %4}], [%1], [%2];" - :: "l"(tma_desc), - "r"(smem_src), - "r"(smem_mbar), + "cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group " + "[%1, {%2, %3}], [%0];" + :: "r"(smem_src), + "l"(tma_desc), "r"(coord_x), "r"(coord_y) : "memory" @@ -222,22 +225,26 @@ __device__ __forceinline__ void tma_store_2d( } /** - * Wait for TMA store completion using mbarrier try_wait. - * Same pattern as tma_mbarrier_wait but for store mbarriers. + * Commit all pending TMA store operations to a group. + * Must be called after issuing TMA stores, before waiting. */ -__device__ __forceinline__ void tma_store_wait(uint32_t smem_mbar, int phase) { - asm volatile( - "{\n\t" - ".reg .pred P1;\n\t" - "LAB_WAIT:" - "mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\n\t" - "@P1 bra.uni DONE;\n\t" - "bra.uni LAB_WAIT;\n\t" - "DONE:\n\t" - "}" - :: "r"(smem_mbar), "r"(phase), "r"(0x989680) - : "memory" - ); +__device__ __forceinline__ void tma_store_commit_group() { + asm volatile("cp.async.bulk.commit_group;" ::: "memory"); +} + +/** + * Wait for N most recent TMA store groups to complete. + * N=0 waits for all groups. N=1 waits for the most recent group. + */ +__device__ __forceinline__ void tma_store_wait_group(int n) { + asm volatile("cp.async.bulk.wait_group.read %0;" :: "r"(n) : "memory"); +} + +/** + * Wait for all pending TMA store groups to complete. + */ +__device__ __forceinline__ void tma_store_wait_all() { + asm volatile("cp.async.bulk.wait_group.read 0;" ::: "memory"); } } // namespace dsv4::kernels::attention