diff --git a/dsv4/kernels/attention/fmha_6warp_tma.cuh b/dsv4/kernels/attention/fmha_6warp_tma.cuh index 4ea5b304..1993f276 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma.cuh @@ -149,7 +149,7 @@ fmha_6warp_tma_kernel(FmhaMultiRowTmaParams params) { // ================================================================== if (tid == 0) { uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); - tma_mbarrier_init(mbar_addr); + tma_mbarrier_init(mbar_addr, 128 * 16 * 2); } // TMEM alloc @@ -178,14 +178,14 @@ fmha_6warp_tma_kernel(FmhaMultiRowTmaParams params) { // Re-init mbarrier if (tid == 0) { uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); - tma_mbarrier_init(mbar_addr); + tma_mbarrier_init(mbar_addr, 128 * 16 * 2); } __syncthreads(); if (is_load_warp && lane == 0) { uint32_t smem_dst = (uint32_t)__cvta_generic_to_shared(sQ_tma); uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); - tma_load_2d(smem_dst, (uint64_t)tma_q, mbar_addr, kt * MMA_K_BF16, 0); + tma_load_3d(smem_dst, (uint64_t)tma_q, mbar_addr, kt * MMA_K_BF16, 0); } // Wait for Q TMA completion @@ -208,14 +208,14 @@ fmha_6warp_tma_kernel(FmhaMultiRowTmaParams params) { // Re-init mbarrier if (tid == 0) { uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); - tma_mbarrier_init(mbar_addr); + tma_mbarrier_init(mbar_addr, 128 * 16 * 2); } __syncthreads(); if (is_load_warp && lane == 0) { uint32_t smem_dst = (uint32_t)__cvta_generic_to_shared(sK_tma); uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); - tma_load_2d(smem_dst, (uint64_t)tma_k, mbar_addr, kt * MMA_K_BF16, 0); + tma_load_3d(smem_dst, (uint64_t)tma_k, mbar_addr, kt * MMA_K_BF16, 0); } if (is_load_warp && lane == 0) { @@ -333,14 +333,14 @@ fmha_6warp_tma_kernel(FmhaMultiRowTmaParams params) { // Re-init mbarrier if (tid == 0) { uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); - tma_mbarrier_init(mbar_addr); + tma_mbarrier_init(mbar_addr, 128 * 16 * 2); } __syncthreads(); if (is_load_warp && lane == 0) { uint32_t smem_dst = (uint32_t)__cvta_generic_to_shared(sV_tma); uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); - tma_load_2d(smem_dst, (uint64_t)tma_v, mbar_addr, col_start, d_base); + tma_load_3d(smem_dst, (uint64_t)tma_v, mbar_addr, col_start, d_base); } if (is_load_warp && lane == 0) { diff --git a/dsv4/kernels/attention/fmha_tma.cuh b/dsv4/kernels/attention/fmha_tma.cuh index bc7bacbf..c48638d4 100644 --- a/dsv4/kernels/attention/fmha_tma.cuh +++ b/dsv4/kernels/attention/fmha_tma.cuh @@ -6,11 +6,11 @@ * ================================================================== * * Replaces scalar GMEM reads in the load warp with async TMA bulk - * copies via cp.async.bulk.tensor.2d. The pipeline: + * copies via cp.async.bulk.tensor.3d. The pipeline: * * Host: CUtensorMap creation for Q, K, V tiles * Kernel: - * 1. TMA warp issues cp.async.bulk.tensor.2d → SMEM (row-major) + * 1. TMA warp issues cp.async.bulk.tensor.3d → SMEM (row-major) * 2. mbarrier wait for TMA completion * 3. Load warp transposes row-major SMEM → canonical K-major SMEM * 4. MMA warp runs tcgen05.mma as before @@ -63,7 +63,7 @@ * TMA async copies use mbarrier for completion signaling: * * 1. Init mbarrier with expected transaction count = 1 - * 2. Issue cp.async.bulk.tensor.2d with the mbarrier + * 2. Issue cp.async.bulk.tensor.3d with the mbarrier * 3. Wait on mbarrier parity (spin or yield) * 4. After wait returns, SMEM data is ready * @@ -125,22 +125,27 @@ inline bool create_tma_desc_2d_bf16( uint32_t tile_cols, // TMA tile dimension 1 CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE ) { - // Global dimensions: (cols, rows) — TMA uses innermost-first ordering - uint64_t global_dim[] = {cols, rows}; - // Global strides: (1, cols) — element stride in each dimension - uint64_t global_str[] = {1, cols}; - // Tile dimensions: (tile_cols, tile_rows) — innermost-first - uint32_t tile_dim[] = {tile_cols, tile_rows}; - // Tile strides: (1, cols) — MUST match global strides, not tile dimensions - // The tile elements are accessed with the same stride as the global tensor. - // Moving from row r to r+1 within the tile means jumping 'cols' elements - // in the global tensor, regardless of tile width. - uint32_t tile_str[] = {1, (uint32_t)cols}; + // CUDA 13: globalStrides are in BYTES, not elements! + // globalStrides[0] = globalDim[0] * elementSizeInBytes = cols * 2 (BF16) + // globalStrides[1] = globalStrides[0] * globalDim[1] = cols * 2 * rows + // But for rank=2, only 1 stride is needed (rank-1) + // + // For rank=3 (recommended for CUDA 13), 2 strides are needed. + // We use rank=3 with degenerate 3rd dimension = 1. + // This avoids rank=2 edge cases and matches CUTLASS convention. + // + // 3D: (cols, rows, 1) innermost-first + uint64_t global_dim[] = {cols, rows, 1}; + uint64_t global_str[] = {cols * 2, cols * 2 * rows}; // byte strides + // Tile: (tile_cols, tile_rows, 1) innermost-first + uint32_t tile_dim[] = {tile_cols, tile_rows, 1}; + // Element strides within tile: 1 for each dim + uint32_t tile_str[] = {1, 1, 1}; CUresult res = cuTensorMapEncodeTiled( out, - CU_TENSOR_MAP_DATA_TYPE_UINT16, // BF16 = 2 bytes = UINT16 - 2, // 2D tensor + CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, // Proper BF16 type + 3, // 3D tensor (degenerate 3rd dim) const_cast(gmem_ptr), global_dim, global_str, tile_dim, tile_str, CU_TENSOR_MAP_INTERLEAVE_NONE, @@ -149,9 +154,9 @@ inline bool create_tma_desc_2d_bf16( CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE ); if (res != CUDA_SUCCESS) { - fprintf(stderr, "cuTensorMapEncodeTiled failed: error=%d, gdim=[%lu,%lu], gstr=[%lu,%lu], tdim=[%u,%u], tstr=[%u,%u]\n", - (int)res, global_dim[0], global_dim[1], global_str[0], global_str[1], - tile_dim[0], tile_dim[1], tile_str[0], tile_str[1]); + fprintf(stderr, "cuTensorMapEncodeTiled failed: error=%d, gdim=[%lu,%lu,%lu], gstr=[%lu,%lu], tdim=[%u,%u,%u], tstr=[%u,%u,%u]\n", + (int)res, global_dim[0], global_dim[1], global_dim[2], global_str[0], global_str[1], + tile_dim[0], tile_dim[1], tile_dim[2], tile_str[0], tile_str[1], tile_str[2]); } return res == CUDA_SUCCESS; } @@ -161,16 +166,18 @@ inline bool create_tma_desc_2d_bf16( // ================================================================== /** - * Initialize an mbarrier in SMEM with expected transaction count = 1. + * Initialize an mbarrier in SMEM with expected byte count. * Only one thread should call this. + * For TMA with complete_tx::bytes, the expected count is the number + * of bytes that will be transferred. */ -__device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar) { +__device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar, uint32_t expected_bytes) { asm volatile("mbarrier.init.shared.b64 [%0], %1;" - :: "r"(smem_mbar), "r"(1)); + :: "r"(smem_mbar), "r"(expected_bytes)); } /** - * Issue a 2D TMA async copy from GMEM to SMEM. + * Issue a 3D TMA async copy from GMEM to SMEM. * * The TMA descriptor must be in device memory (GMEM). * Only ONE thread per CTA should issue the TMA copy. @@ -181,24 +188,27 @@ __device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar) { * @param smem_dst SMEM destination address (via __cvta_generic_to_shared) * @param tma_desc Pointer to CUtensorMap in device memory (uint64_t cast) * @param smem_mbar SMEM mbarrier address (via __cvta_generic_to_shared) - * @param coord_x Column coordinate (innermost dimension) - * @param coord_y Row coordinate (outermost dimension) + * @param coord_x Innermost dimension coordinate (columns) + * @param coord_y Middle dimension coordinate (rows) + * @param coord_z Outermost dimension coordinate (degenerate = 0) */ -__device__ __forceinline__ void tma_load_2d( +__device__ __forceinline__ void tma_load_3d( uint32_t smem_dst, uint64_t tma_desc, uint32_t smem_mbar, int coord_x, - int coord_y + int coord_y, + int coord_z = 0 ) { asm volatile( - "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes " - "[%0], [%1, {%3, %4}], [%2];" + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes " + "[%0], [%1, {%3, %4, %5}], [%2];" :: "r"(smem_dst), "l"(tma_desc), "r"(smem_mbar), "r"(coord_x), - "r"(coord_y) + "r"(coord_y), + "r"(coord_z) : "memory" ); }