P6: Fix TMA store — use bulk_group (commit+wait) not mbarrier
TMA store uses cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group, not mbarrier::complete_tx::bytes. Completion is tracked via cp.async.bulk.commit_group (issued after the stores) and cp.async.bulk.wait_group.read 0 (waits for all committed groups). Removed sMbarStore from the SMEM allocations, since it is no longer needed.
This commit is contained in:
@@ -148,8 +148,6 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
|
||||
float* s_p_vals = (float*)(sV + V_SUB_SZ);
|
||||
// Epilogue SMEM: row-major O tile for TMA store
|
||||
bf16_t* sO_epi = (bf16_t*)(((uintptr_t)(s_p_vals + SK_TILE) + 127) & ~(uintptr_t)127);
|
||||
// TMA store mbarrier (16 bytes, 128B aligned)
|
||||
uint64_t* sMbarStore = (uint64_t*)(((uintptr_t)(sO_epi + HD) + 127) & ~(uintptr_t)127);
|
||||
|
||||
// ================================================================
|
||||
// TMEM allocation (warp 4)
|
||||
@@ -320,37 +318,27 @@ fmha_6warp_multihead_kernel(FmhaParams params) {
|
||||
// Step 4: TMA store SMEM → GMEM (or direct GMEM write)
|
||||
if (params.tma_o) {
|
||||
// Proper TMA store path — async, enables multi-CTA grids.
|
||||
// One thread initializes the mbarrier and issues the TMA store.
|
||||
if (tid == 0) {
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
|
||||
tma_mbarrier_init(mbar_addr, 1);
|
||||
asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Uses bulk_group completion (commit_group + wait_group), NOT mbarrier.
|
||||
//
|
||||
// The O tensor is [batch, n_h, T, HD]. Each head's output starts
|
||||
// at a different GMEM offset. The TMA descriptor covers the full
|
||||
// O tensor. We index by (head, batch) to find the right descriptor.
|
||||
// at a different GMEM offset. The TMA descriptor covers one head's
|
||||
// output. We index by (head, batch) to find the right descriptor.
|
||||
//
|
||||
// TMA coords for this head's row 0: (x=0, y=0)
|
||||
// The descriptor was created with the head's GMEM base pointer,
|
||||
// so coords are relative to that head's start.
|
||||
if (tid == 0) {
|
||||
uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sO_epi);
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
|
||||
// Get this (head, batch) TMA descriptor
|
||||
CUtensorMap* my_tma = params.tma_o + batch_idx * gridDim.y + head_idx;
|
||||
uint64_t tma_desc = (uint64_t)my_tma;
|
||||
tma_store_2d(smem_addr, tma_desc, mbar_addr, 0, 0);
|
||||
// TMA tile: (1, HD) BF16 = HD*2 bytes
|
||||
tma_mbarrier_arrive_expect_tx(mbar_addr, HD * sizeof(bf16_t));
|
||||
tma_store_2d(smem_addr, tma_desc, 0, 0);
|
||||
tma_store_commit_group();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Wait for TMA store completion
|
||||
if (tid == 0) {
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
|
||||
tma_store_wait(mbar_addr, 0);
|
||||
tma_store_wait_all();
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
|
||||
@@ -108,7 +108,6 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
|
||||
// P6: Row-major O epilogue buffer + TMA store mbarrier
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sO_epi_rowmajor = (bf16_t*)(sbuf + off); off += MAX_ROWS * HD_CHUNK * sizeof(bf16_t);
|
||||
uint64_t* sMbarStore = (uint64_t*)(sbuf + off); off += 16;
|
||||
|
||||
// Init TMEM + mbarrier (once, shared across hd_chunks)
|
||||
if (is_mma_warp) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
|
||||
@@ -359,26 +358,17 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
|
||||
|
||||
// TMA store sO_epi_rowmajor → GMEM
|
||||
if (params.tma_o) {
|
||||
if (tid == 0) {
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
|
||||
tma_mbarrier_init(mbar_addr, 1);
|
||||
asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (tid == 0) {
|
||||
uint32_t smem_addr = (uint32_t)__cvta_generic_to_shared(sO_epi_rowmajor);
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
|
||||
CUtensorMap* my_tma = params.tma_o + batch_idx * params.n_h + head_idx;
|
||||
uint64_t tma_desc = (uint64_t)my_tma;
|
||||
// TMA coords: x = hd_chunk_start, y = 0
|
||||
tma_store_2d(smem_addr, tma_desc, mbar_addr, hd_chunk_start, 0);
|
||||
tma_mbarrier_arrive_expect_tx(mbar_addr, T * HD_CHUNK * sizeof(bf16_t));
|
||||
tma_store_2d(smem_addr, tma_desc, hd_chunk_start, 0);
|
||||
tma_store_commit_group();
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid == 0) {
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbarStore);
|
||||
tma_store_wait(mbar_addr, 0);
|
||||
tma_store_wait_all();
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
|
||||
@@ -39,9 +39,8 @@ int fmha_compute_smem(int hd) {
|
||||
int sV = V_SUB_SZ * 2; // 512
|
||||
int sp = SK * 4; // 512
|
||||
int sO = hd * 2; // row-major O (HD BF16)
|
||||
int sMbar = 16; // TMA store mbarrier
|
||||
// With 128B alignment between sections
|
||||
int total = base + sQ0 + sK0 + (sPk + 127) + (sV + 127) + sp + (sO + 127) + (sMbar + 127) + 256;
|
||||
int total = base + sQ0 + sK0 + (sPk + 127) + (sV + 127) + sp + (sO + 127) + 256;
|
||||
// Round up to 128B
|
||||
return (total + 127) & ~127;
|
||||
}
|
||||
|
||||
@@ -109,7 +109,6 @@ int fmha_multitile_decode_launch(
|
||||
// P6: sO_epi_rowmajor + sMbarStore
|
||||
off = (off+127)&~(size_t)127;
|
||||
off += 128 * hc * 2; // sO_epi_rowmajor (MAX_ROWS * HD_CHUNK BF16)
|
||||
off += 16; // sMbarStore
|
||||
off += 256; // slack
|
||||
int smem = (int)((off + 127) & ~(size_t)127);
|
||||
|
||||
|
||||
@@ -196,25 +196,28 @@ struct FmhaTmaDescriptors {
|
||||
/**
|
||||
* Issue a 2D TMA async copy from SMEM to GMEM (store).
|
||||
*
|
||||
* PTX syntax: cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group
|
||||
* [tensorMap, tensorCoords], [srcMem];
|
||||
*
|
||||
* IMPORTANT: TMA store uses bulk_group (not mbarrier) for completion tracking.
|
||||
* Completion is tracked via cp.async.bulk.commit_group + cp.async.bulk.wait_group.
|
||||
*
|
||||
* @param smem_src SMEM source address (via __cvta_generic_to_shared)
|
||||
* @param tma_desc Pointer to CUtensorMap in device memory (uint64_t cast)
|
||||
* @param smem_mbar SMEM mbarrier address (via __cvta_generic_to_shared)
|
||||
* @param coord_x Column coordinate (innermost dimension)
|
||||
* @param coord_y Row coordinate (outermost dimension)
|
||||
*/
|
||||
__device__ __forceinline__ void tma_store_2d(
|
||||
uint32_t smem_src,
|
||||
uint64_t tma_desc,
|
||||
uint32_t smem_mbar,
|
||||
int coord_x,
|
||||
int coord_y
|
||||
) {
|
||||
asm volatile(
|
||||
"cp.async.bulk.tensor.2d.global.shared::cluster.mbarrier::complete_tx::bytes "
|
||||
"[%0, {%3, %4}], [%1], [%2];"
|
||||
:: "l"(tma_desc),
|
||||
"r"(smem_src),
|
||||
"r"(smem_mbar),
|
||||
"cp.async.bulk.tensor.2d.global.shared::cta.tile.bulk_group "
|
||||
"[%1, {%2, %3}], [%0];"
|
||||
:: "r"(smem_src),
|
||||
"l"(tma_desc),
|
||||
"r"(coord_x),
|
||||
"r"(coord_y)
|
||||
: "memory"
|
||||
@@ -222,22 +225,26 @@ __device__ __forceinline__ void tma_store_2d(
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for TMA store completion using mbarrier try_wait.
|
||||
* Same pattern as tma_mbarrier_wait but for store mbarriers.
|
||||
* Commit all pending TMA store operations to a group.
|
||||
* Must be called after issuing TMA stores, before waiting.
|
||||
*/
|
||||
__device__ __forceinline__ void tma_store_wait(uint32_t smem_mbar, int phase) {
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred P1;\n\t"
|
||||
"LAB_WAIT:"
|
||||
"mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\n\t"
|
||||
"@P1 bra.uni DONE;\n\t"
|
||||
"bra.uni LAB_WAIT;\n\t"
|
||||
"DONE:\n\t"
|
||||
"}"
|
||||
:: "r"(smem_mbar), "r"(phase), "r"(0x989680)
|
||||
: "memory"
|
||||
);
|
||||
__device__ __forceinline__ void tma_store_commit_group() {
    // Per the PTX ISA, this closes the calling thread's current bulk
    // async-group: the bulk_group copies that thread has issued but not yet
    // committed are batched into one group that a later
    // cp.async.bulk.wait_group can wait on.
    asm volatile(
        "cp.async.bulk.commit_group;"
        :            // no outputs
        :            // no inputs
        : "memory"   // compiler barrier: do not reorder memory accesses across the commit
    );
}
|
||||
|
||||
/**
 * Wait until at most `n` of the calling thread's most recently committed
 * bulk async-groups (TMA stores) are still pending.
 *
 * n=0 waits for every committed group. n=1 allows the most recent group to
 * remain in flight and waits for all older ones, and so on.
 *
 * The `.read` form is used: per the PTX ISA it waits only until the source
 * (shared memory) of the stores has been read, which is the condition needed
 * before the SMEM staging buffer may be overwritten.
 *
 * @param n  Maximum number of most-recent groups allowed to stay pending.
 *           PTX requires this count to be an integer constant, so it cannot
 *           be passed through a register operand. Values 1..3 map to the
 *           matching immediate; any other value (including negatives and
 *           counts above 3) falls back to 0, which is stricter than requested
 *           and therefore always safe.
 */
__device__ __forceinline__ void tma_store_wait_group(int n) {
    // Each case carries its own immediate operand; after inlining with a
    // constant argument the compiler keeps only the selected instruction.
    switch (n) {
        case 1:
            asm volatile("cp.async.bulk.wait_group.read 1;" ::: "memory");
            break;
        case 2:
            asm volatile("cp.async.bulk.wait_group.read 2;" ::: "memory");
            break;
        case 3:
            asm volatile("cp.async.bulk.wait_group.read 3;" ::: "memory");
            break;
        default:
            // Conservative fallback: drain all committed groups.
            asm volatile("cp.async.bulk.wait_group.read 0;" ::: "memory");
            break;
    }
}
|
||||
|
||||
/**
 * Drain every bulk async-group (TMA store) the calling thread has committed.
 *
 * Issues cp.async.bulk.wait_group.read with a count of 0, i.e. no committed
 * group may remain pending when this returns. Per the PTX ISA, the `.read`
 * form waits until the stores' shared-memory source has been read.
 */
__device__ __forceinline__ void tma_store_wait_all() {
    asm volatile(
        "cp.async.bulk.wait_group.read 0;"
        :            // no outputs
        :            // no inputs
        : "memory"   // compiler barrier around the wait
    );
}
|
||||
|
||||
} // namespace dsv4::kernels::attention
|
||||
|
||||
Reference in New Issue
Block a user