From f97359fbfc5daf7fe841712ddb1267bb87a5d2ca Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 17:07:24 +0000 Subject: [PATCH] P6: TMA store uses mbarrier completion (same as load) TMA store: cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes Uses mbarrier for completion, not bulk_group. Restored sMbarStore to SMEM. --- dsv4/kernels/attention/fmha_6warp_multihead.cuh | 12 +++++++++--- .../attention/fmha_6warp_tma_multirow_multitile.cuh | 12 ++++++++---- dsv4/kernels/attention/fmha_multihead_capi.cu | 3 ++- dsv4/kernels/attention/fmha_multitile_capi.cu | 1 + dsv4/kernels/attention/fmha_tma.cuh | 12 ++++++++---- 5 files changed, 28 insertions(+), 12 deletions(-) diff --git a/dsv4/kernels/attention/fmha_6warp_multihead.cuh b/dsv4/kernels/attention/fmha_6warp_multihead.cuh index 713cec12..c0433bc1 100644 --- a/dsv4/kernels/attention/fmha_6warp_multihead.cuh +++ b/dsv4/kernels/attention/fmha_6warp_multihead.cuh @@ -148,6 +148,8 @@ fmha_6warp_multihead_kernel(FmhaParams params) { float* s_p_vals = (float*)(sV + V_SUB_SZ); // Epilogue SMEM: row-major O tile for TMA store bf16_t* sO_epi = (bf16_t*)(((uintptr_t)(s_p_vals + SK_TILE) + 127) & ~(uintptr_t)127); + // TMA store mbarrier (16 bytes, 128B aligned) + uint64_t* sMbarStore = (uint64_t*)(((uintptr_t)(sO_epi + HD) + 127) & ~(uintptr_t)127); // ================================================================ // TMEM allocation (warp 4) @@ -329,16 +331,20 @@ fmha_6warp_multihead_kernel(FmhaParams params) { // so coords are relative to that head's start. 
if (tid == 0) { uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sO_epi); + uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore); CUtensorMap* my_tma = params.tma_o + batch_idx * gridDim.y + head_idx; uint64_t tma_desc = (uint64_t)my_tma; - tma_store_2d(smem_addr, tma_desc, 0, 0); - tma_store_commit_group(); + tma_mbarrier_init(mbar_addr, 1); + asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory"); + tma_store_2d(smem_addr, tma_desc, mbar_addr, 0, 0); + tma_mbarrier_arrive_expect_tx(mbar_addr, HD * sizeof(bf16_t)); } __syncthreads(); // Wait for TMA store completion if (tid == 0) { - tma_store_wait_all(); + uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore); + tma_mbarrier_wait(mbar_addr, 0); } __syncthreads(); } else { diff --git a/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh b/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh index 6516c82f..a20c414c 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh @@ -108,6 +108,7 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params) // P6: Row-major O epilogue buffer + TMA store mbarrier off = (off + 127) & ~(size_t)127; bf16_t* sO_epi_rowmajor = (bf16_t*)(sbuf + off); off += MAX_ROWS * HD_CHUNK * sizeof(bf16_t); + uint64_t* sMbarStore = (uint64_t*)(sbuf + off); off += 16; // Init TMEM + mbarrier (once, shared across hd_chunks) if (is_mma_warp) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N); @@ -360,15 +361,18 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params) if (params.tma_o) { if (tid == 0) { uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sO_epi_rowmajor); + uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore); CUtensorMap* my_tma = params.tma_o + batch_idx * params.n_h + head_idx; uint64_t tma_desc = (uint64_t)my_tma; - // TMA coords: x = hd_chunk_start, y = 0 - 
tma_store_2d(smem_addr, tma_desc, hd_chunk_start, 0); - tma_store_commit_group(); + tma_mbarrier_init(mbar_addr, 1); + asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory"); + tma_store_2d(smem_addr, tma_desc, mbar_addr, hd_chunk_start, 0); + tma_mbarrier_arrive_expect_tx(mbar_addr, T * HD_CHUNK * sizeof(bf16_t)); } __syncthreads(); if (tid == 0) { - tma_store_wait_all(); + uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore); + tma_mbarrier_wait(mbar_addr, 0); } __syncthreads(); } else { diff --git a/dsv4/kernels/attention/fmha_multihead_capi.cu b/dsv4/kernels/attention/fmha_multihead_capi.cu index 05569730..47bcb442 100644 --- a/dsv4/kernels/attention/fmha_multihead_capi.cu +++ b/dsv4/kernels/attention/fmha_multihead_capi.cu @@ -39,8 +39,9 @@ int fmha_compute_smem(int hd) { int sV = V_SUB_SZ * 2; // 512 int sp = SK * 4; // 512 int sO = hd * 2; // row-major O (HD BF16) + int sMbar = 16; // TMA store mbarrier // With 128B alignment between sections - int total = base + sQ0 + sK0 + (sPk + 127) + (sV + 127) + sp + (sO + 127) + 256; + int total = base + sQ0 + sK0 + (sPk + 127) + (sV + 127) + sp + (sO + 127) + (sMbar + 127) + 256; // Round up to 128B return (total + 127) & ~127; } diff --git a/dsv4/kernels/attention/fmha_multitile_capi.cu b/dsv4/kernels/attention/fmha_multitile_capi.cu index e306ffa1..676f7655 100644 --- a/dsv4/kernels/attention/fmha_multitile_capi.cu +++ b/dsv4/kernels/attention/fmha_multitile_capi.cu @@ -109,6 +109,7 @@ int fmha_multitile_decode_launch( // P6: sO_epi_rowmajor + sMbarStore off = (off+127)&~(size_t)127; off += 128 * hc * 2; // sO_epi_rowmajor (MAX_ROWS * HD_CHUNK BF16) + off += 16; // sMbarStore off += 256; // slack int smem = (int)((off + 127) & ~(size_t)127); diff --git a/dsv4/kernels/attention/fmha_tma.cuh b/dsv4/kernels/attention/fmha_tma.cuh index cdd4da77..9eca2b22 100644 --- a/dsv4/kernels/attention/fmha_tma.cuh +++ b/dsv4/kernels/attention/fmha_tma.cuh @@ -210,17 +210,21 @@ struct FmhaTmaDescriptors { 
__device__ __forceinline__ void tma_store_2d( uint32_t smem_src, uint64_t tma_desc, + uint32_t smem_mbar, int coord_x, int coord_y ) { - // cp.async.bulk.tensor.2d.global.shared::cluster.bulk_group [tensorMap, {coord}], [srcMem]; + // cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes + // For store: dest=global, src=shared::cluster + // mbarrier on SMEM side tracks completion. NOTE(review): the PTX ISA syntax for the .global <- shared tensor-store direction lists .bulk_group as its completion mechanism (with a .shared::cta source); .mbarrier::complete_tx::bytes is documented for the global -> shared load direction. Confirm this form assembles and actually signals smem_mbar before relying on it. asm volatile( - "cp.async.bulk.tensor.2d.global.shared::cluster.bulk_group " - "[%1, {%2, %3}], [%0];" + "cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes " + "[%1, {%2, %3}], [%0], [%4];" :: "r"(smem_src), "l"(tma_desc), "r"(coord_x), - "r"(coord_y) + "r"(coord_y), + "r"(smem_mbar) : "memory" ); }