fix: CUDA 13 TMA descriptor — 3D rank + byte strides + mbarrier byte count
Five critical fixes for CUDA 13.x on Blackwell: 1. globalStrides are in BYTES, not elements (CUDA 13 change); 2. Use 3D descriptors (degenerate 3rd dim = 1) — CUDA 13 TMA requires rank >= 2; 3. mbarrier init uses the expected byte count (4096 for a 128x16 BF16 tile); 4. cp.async.bulk.tensor.3d instead of .2d for 3D descriptors; 5. BFLOAT16 data type instead of UINT16.
This commit is contained in:
@@ -149,7 +149,7 @@ fmha_6warp_tma_kernel(FmhaMultiRowTmaParams params) {
|
||||
// ==================================================================
|
||||
if (tid == 0) {
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
|
||||
tma_mbarrier_init(mbar_addr);
|
||||
tma_mbarrier_init(mbar_addr, 128 * 16 * 2);
|
||||
}
|
||||
|
||||
// TMEM alloc
|
||||
@@ -178,14 +178,14 @@ fmha_6warp_tma_kernel(FmhaMultiRowTmaParams params) {
|
||||
// Re-init mbarrier
|
||||
if (tid == 0) {
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
|
||||
tma_mbarrier_init(mbar_addr);
|
||||
tma_mbarrier_init(mbar_addr, 128 * 16 * 2);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (is_load_warp && lane == 0) {
|
||||
uint32_t smem_dst = (uint32_t)__cvta_generic_to_shared(sQ_tma);
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
|
||||
tma_load_2d(smem_dst, (uint64_t)tma_q, mbar_addr, kt * MMA_K_BF16, 0);
|
||||
tma_load_3d(smem_dst, (uint64_t)tma_q, mbar_addr, kt * MMA_K_BF16, 0);
|
||||
}
|
||||
|
||||
// Wait for Q TMA completion
|
||||
@@ -208,14 +208,14 @@ fmha_6warp_tma_kernel(FmhaMultiRowTmaParams params) {
|
||||
// Re-init mbarrier
|
||||
if (tid == 0) {
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
|
||||
tma_mbarrier_init(mbar_addr);
|
||||
tma_mbarrier_init(mbar_addr, 128 * 16 * 2);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (is_load_warp && lane == 0) {
|
||||
uint32_t smem_dst = (uint32_t)__cvta_generic_to_shared(sK_tma);
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
|
||||
tma_load_2d(smem_dst, (uint64_t)tma_k, mbar_addr, kt * MMA_K_BF16, 0);
|
||||
tma_load_3d(smem_dst, (uint64_t)tma_k, mbar_addr, kt * MMA_K_BF16, 0);
|
||||
}
|
||||
|
||||
if (is_load_warp && lane == 0) {
|
||||
@@ -333,14 +333,14 @@ fmha_6warp_tma_kernel(FmhaMultiRowTmaParams params) {
|
||||
// Re-init mbarrier
|
||||
if (tid == 0) {
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
|
||||
tma_mbarrier_init(mbar_addr);
|
||||
tma_mbarrier_init(mbar_addr, 128 * 16 * 2);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (is_load_warp && lane == 0) {
|
||||
uint32_t smem_dst = (uint32_t)__cvta_generic_to_shared(sV_tma);
|
||||
uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
|
||||
tma_load_2d(smem_dst, (uint64_t)tma_v, mbar_addr, col_start, d_base);
|
||||
tma_load_3d(smem_dst, (uint64_t)tma_v, mbar_addr, col_start, d_base);
|
||||
}
|
||||
|
||||
if (is_load_warp && lane == 0) {
|
||||
|
||||
@@ -6,11 +6,11 @@
|
||||
* ==================================================================
|
||||
*
|
||||
* Replaces scalar GMEM reads in the load warp with async TMA bulk
|
||||
* copies via cp.async.bulk.tensor.2d. The pipeline:
|
||||
* copies via cp.async.bulk.tensor.3d. The pipeline:
|
||||
*
|
||||
* Host: CUtensorMap creation for Q, K, V tiles
|
||||
* Kernel:
|
||||
* 1. TMA warp issues cp.async.bulk.tensor.2d → SMEM (row-major)
|
||||
* 1. TMA warp issues cp.async.bulk.tensor.3d → SMEM (row-major)
|
||||
* 2. mbarrier wait for TMA completion
|
||||
* 3. Load warp transposes row-major SMEM → canonical K-major SMEM
|
||||
* 4. MMA warp runs tcgen05.mma as before
|
||||
@@ -63,7 +63,7 @@
|
||||
* TMA async copies use mbarrier for completion signaling:
|
||||
*
|
||||
* 1. Init mbarrier with the expected byte count of the tile (e.g. 128 * 16 * 2)
|
||||
* 2. Issue cp.async.bulk.tensor.2d with the mbarrier
|
||||
* 2. Issue cp.async.bulk.tensor.3d with the mbarrier
|
||||
* 3. Wait on mbarrier parity (spin or yield)
|
||||
* 4. After wait returns, SMEM data is ready
|
||||
*
|
||||
@@ -125,22 +125,27 @@ inline bool create_tma_desc_2d_bf16(
|
||||
uint32_t tile_cols, // TMA tile dimension 1
|
||||
CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE
|
||||
) {
|
||||
// Global dimensions: (cols, rows) — TMA uses innermost-first ordering
|
||||
uint64_t global_dim[] = {cols, rows};
|
||||
// Global strides: (1, cols) — element stride in each dimension
|
||||
uint64_t global_str[] = {1, cols};
|
||||
// Tile dimensions: (tile_cols, tile_rows) — innermost-first
|
||||
uint32_t tile_dim[] = {tile_cols, tile_rows};
|
||||
// Tile strides: (1, cols) — MUST match global strides, not tile dimensions
|
||||
// The tile elements are accessed with the same stride as the global tensor.
|
||||
// Moving from row r to r+1 within the tile means jumping 'cols' elements
|
||||
// in the global tensor, regardless of tile width.
|
||||
uint32_t tile_str[] = {1, (uint32_t)cols};
|
||||
// CUDA 13: globalStrides are in BYTES, not elements!
|
||||
// globalStrides[0] = globalDim[0] * elementSizeInBytes = cols * 2 (BF16)
|
||||
// globalStrides[1] = globalStrides[0] * globalDim[1] = cols * 2 * rows
|
||||
// But for rank=2, only 1 stride is needed (rank-1)
|
||||
//
|
||||
// For rank=3 (recommended for CUDA 13), 2 strides are needed.
|
||||
// We use rank=3 with degenerate 3rd dimension = 1.
|
||||
// This avoids rank=2 edge cases and matches CUTLASS convention.
|
||||
//
|
||||
// 3D: (cols, rows, 1) innermost-first
|
||||
uint64_t global_dim[] = {cols, rows, 1};
|
||||
uint64_t global_str[] = {cols * 2, cols * 2 * rows}; // byte strides
|
||||
// Tile: (tile_cols, tile_rows, 1) innermost-first
|
||||
uint32_t tile_dim[] = {tile_cols, tile_rows, 1};
|
||||
// Element strides within tile: 1 for each dim
|
||||
uint32_t tile_str[] = {1, 1, 1};
|
||||
|
||||
CUresult res = cuTensorMapEncodeTiled(
|
||||
out,
|
||||
CU_TENSOR_MAP_DATA_TYPE_UINT16, // BF16 = 2 bytes = UINT16
|
||||
2, // 2D tensor
|
||||
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, // Proper BF16 type
|
||||
3, // 3D tensor (degenerate 3rd dim)
|
||||
const_cast<void*>(gmem_ptr),
|
||||
global_dim, global_str, tile_dim, tile_str,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||
@@ -149,9 +154,9 @@ inline bool create_tma_desc_2d_bf16(
|
||||
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
|
||||
);
|
||||
if (res != CUDA_SUCCESS) {
|
||||
fprintf(stderr, "cuTensorMapEncodeTiled failed: error=%d, gdim=[%lu,%lu], gstr=[%lu,%lu], tdim=[%u,%u], tstr=[%u,%u]\n",
|
||||
(int)res, global_dim[0], global_dim[1], global_str[0], global_str[1],
|
||||
tile_dim[0], tile_dim[1], tile_str[0], tile_str[1]);
|
||||
fprintf(stderr, "cuTensorMapEncodeTiled failed: error=%d, gdim=[%lu,%lu,%lu], gstr=[%lu,%lu], tdim=[%u,%u,%u], tstr=[%u,%u,%u]\n",
|
||||
(int)res, global_dim[0], global_dim[1], global_dim[2], global_str[0], global_str[1],
|
||||
tile_dim[0], tile_dim[1], tile_dim[2], tile_str[0], tile_str[1], tile_str[2]);
|
||||
}
|
||||
return res == CUDA_SUCCESS;
|
||||
}
|
||||
@@ -161,16 +166,18 @@ inline bool create_tma_desc_2d_bf16(
|
||||
// ==================================================================
|
||||
|
||||
/**
|
||||
* Initialize an mbarrier in SMEM with expected transaction count = 1.
|
||||
* Initialize an mbarrier in SMEM with expected byte count.
|
||||
* Only one thread should call this.
|
||||
* For TMA with complete_tx::bytes, the expected count is the number
|
||||
* of bytes that will be transferred.
|
||||
*/
|
||||
__device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar) {
|
||||
__device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar, uint32_t expected_bytes) {
|
||||
asm volatile("mbarrier.init.shared.b64 [%0], %1;"
|
||||
:: "r"(smem_mbar), "r"(1));
|
||||
:: "r"(smem_mbar), "r"(expected_bytes));
|
||||
}
|
||||
|
||||
/**
|
||||
* Issue a 2D TMA async copy from GMEM to SMEM.
|
||||
* Issue a 3D TMA async copy from GMEM to SMEM.
|
||||
*
|
||||
* The TMA descriptor must be in device memory (GMEM).
|
||||
* Only ONE thread per CTA should issue the TMA copy.
|
||||
@@ -181,24 +188,27 @@ __device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar) {
|
||||
* @param smem_dst SMEM destination address (via __cvta_generic_to_shared)
|
||||
* @param tma_desc Pointer to CUtensorMap in device memory (uint64_t cast)
|
||||
* @param smem_mbar SMEM mbarrier address (via __cvta_generic_to_shared)
|
||||
* @param coord_x Column coordinate (innermost dimension)
|
||||
* @param coord_y Row coordinate (outermost dimension)
|
||||
* @param coord_x Innermost dimension coordinate (columns)
|
||||
* @param coord_y Middle dimension coordinate (rows)
|
||||
* @param coord_z Outermost dimension coordinate (degenerate = 0)
|
||||
*/
|
||||
__device__ __forceinline__ void tma_load_2d(
|
||||
__device__ __forceinline__ void tma_load_3d(
|
||||
uint32_t smem_dst,
|
||||
uint64_t tma_desc,
|
||||
uint32_t smem_mbar,
|
||||
int coord_x,
|
||||
int coord_y
|
||||
int coord_y,
|
||||
int coord_z = 0
|
||||
) {
|
||||
asm volatile(
|
||||
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes "
|
||||
"[%0], [%1, {%3, %4}], [%2];"
|
||||
"cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes "
|
||||
"[%0], [%1, {%3, %4, %5}], [%2];"
|
||||
:: "r"(smem_dst),
|
||||
"l"(tma_desc),
|
||||
"r"(smem_mbar),
|
||||
"r"(coord_x),
|
||||
"r"(coord_y)
|
||||
"r"(coord_y),
|
||||
"r"(coord_z)
|
||||
: "memory"
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user