- Deleted fmha.py (CuTeDSL slow path), FmhaKernel, Python KV merge - Deleted fmha_sm100.cuh, fmha_sm100_tc.cuh, fmha_sm100_launch.cu, fmha_epilogue_sm100.cuh - Moved fmha_qk_verify.cuh to tests/unit/qk_verify_kernel.cuh - Deleted decode_sparse.py, decode_swa.py, kernels/decode/ - Deleted 46 test_d*.py probes, test_smem_*, test_cotiled_*, test_tmem_*, test_smem_p_*, test_ultra_minimal, test_fmha_pv16, test_working_softmax_maybe - Deleted root scratch: debug_linear.py, test_mapping.py, run_router_tests.py - Moved archive/ to archived_plans/code_archive/ - Rewrote production.py: single fast path via 6-warp multi-tile kernel - Added STATUS.md, audit_attention_live.md - Moved NEXT_PRIORITIES*.md to archived_plans/
258 lines
9.9 KiB
Plaintext
258 lines
9.9 KiB
Plaintext
/**
|
||
* DSV4 FMHA — TMA async load infrastructure for Blackwell SM100.
|
||
*
|
||
* ==================================================================
|
||
* DESIGN
|
||
* ==================================================================
|
||
*
|
||
* Replaces scalar GMEM reads in the load warp with async TMA bulk
|
||
* copies via cp.async.bulk.tensor.3d. The pipeline:
|
||
*
|
||
* Host: CUtensorMap creation for Q, K, V tiles
|
||
* Kernel:
|
||
* 1. TMA warp issues cp.async.bulk.tensor.3d → SMEM (row-major)
|
||
* 2. mbarrier wait for TMA completion
|
||
* 3. Load warp transposes row-major SMEM → canonical K-major SMEM
|
||
* 4. MMA warp runs tcgen05.mma as before
|
||
*
|
||
* For double-buffered pipeline overlap (future):
|
||
* - Two SMEM buffers per tensor (sQ0/sQ1, sK0/sK1)
|
||
* - TMA load of K-tile (kt+1) overlaps with MMA on K-tile (kt)
|
||
* - Pipeline stages managed via mbarrier arrive/wait
|
||
*
|
||
* ==================================================================
|
||
* TMA DESCRIPTOR LAYOUT
|
||
* ==================================================================
|
||
*
|
||
* We create CUtensorMap descriptors (2D tensors encoded as rank 3 with a degenerate outer dimension of 1) for each tile the kernel needs:
|
||
*
|
||
* Q tile: (T, HD) — one tile for the full Q
|
||
* K tile: (s_k, HD) — one tile for the full K (or (128, 16) per K-sub-tile)
|
||
* V tile: (HD, s_k) — transposed, one tile for the full V
|
||
*
|
||
* TMA copies data from GMEM to SMEM in row-major order. After TMA
|
||
* completion, the load warp transposes from row-major to the
|
||
* canonical K-major core-matrix layout that tcgen05.mma expects.
|
||
*
|
||
* For the multirow kernel, Q is (T, HD) and K is (s_k, HD).
|
||
* Since TMA operates on 2D tiles and our SMEM is (128, 16) per
|
||
* MMA K-tile, we have two choices:
|
||
*
|
||
* Option A: TMA load full (T, HD) → row-major SMEM → transpose
|
||
* - One TMA descriptor for Q, one for K
|
||
* - Larger SMEM footprint (need row-major + canonical)
|
||
* - Simpler descriptor management
|
||
*
|
||
* Option B: TMA load per (128, 16) K-sub-tile
|
||
* - One TMA descriptor, multiple TMA issues with different coords
|
||
* - Same SMEM as current (no double buffer needed for single-stage)
|
||
* - Matches the existing K-tiling loop structure
|
||
*
|
||
* We choose Option B: TMA per (128, 16) K-sub-tile. This:
|
||
* - Reuses the exact same SMEM layout as the current kernel
|
||
* - Fits the existing QK loop structure (kt = 0..NKT_QK-1)
|
||
* - Enables future pipeline overlap with minimal changes
|
||
* - The TMA descriptor covers the full (T, HD) or (s_k, HD) tensor,
|
||
* and we issue TMA loads for specific (col, row) coordinates
|
||
* targeting each 128×16 tile
|
||
*
|
||
* ==================================================================
|
||
* MBARRIER PROTOCOL
|
||
* ==================================================================
|
||
*
|
||
* TMA async copies use mbarrier for completion signaling:
|
||
*
|
||
* 1. Init mbarrier with arrival count = 1, then arrive once while registering the expected transaction byte count (mbarrier.arrive.expect_tx)
|
||
* 2. Issue cp.async.bulk.tensor.3d with the mbarrier
|
||
* 3. Wait on mbarrier parity (spin or yield)
|
||
* 4. After wait returns, SMEM data is ready
|
||
*
|
||
* The mbarrier lives in SMEM. One mbarrier per outstanding TMA
|
||
* operation. For single-stage (no overlap), we use one mbarrier
|
||
* and wait immediately after issue.
|
||
*
|
||
* ==================================================================
|
||
* SWIZZLE CONSIDERATIONS
|
||
* ==================================================================
|
||
*
|
||
* TMA descriptors support SWIZZLE_NONE, SWIZZLE_32B, SWIZZLE_64B,
|
||
* SWIZZLE_128B. The swizzle pattern in SMEM matches what UMMA
|
||
* descriptors expect when using make_umma_desc_kmajor_sw128.
|
||
*
|
||
* Current kernel uses SWIZZLE_NONE (make_umma_desc_kmajor_none).
|
||
* With TMA, we have two paths:
|
||
*
|
||
* Path 1: TMA with SWIZZLE_NONE → SMEM is row-major → transpose to canonical
|
||
* Path 2: TMA with SWIZZLE_128B → SMEM is swizzled → UMMA reads directly
|
||
*
|
||
* Path 2 is the production target: no transpose needed, TMA writes
|
||
* in the exact layout MMA reads. But getting the swizzle right is
|
||
* tricky and needs careful verification.
|
||
*
|
||
* We start with Path 1 (SWIZZLE_NONE + transpose) to get TMA working,
|
||
* then upgrade to Path 2 (SWIZZLE_128B, zero-copy) for performance.
|
||
* ==================================================================
|
||
*/
|
||
|
||
#pragma once
|
||
|
||
#include "fmha_common.cuh"
|
||
#include <cstdint>
|
||
|
||
namespace dsv4::kernels::attention {
|
||
|
||
// ==================================================================
|
||
// TMA descriptor helpers (host-side)
|
||
// ==================================================================
|
||
// These are called from host code to create CUtensorMap objects
|
||
// that the kernel uses for TMA async copies.
|
||
// ==================================================================
|
||
|
||
/**
 * Create a TMA descriptor for a row-major BF16 tensor of shape (rows, cols).
 *
 * The tensor is row-major in GMEM with a row stride of `cols` elements
 * (cols * 2 bytes). Each TMA copy issued against the descriptor moves one
 * box of (tile_rows, tile_cols) elements.
 *
 * Although the tensor is 2D, it is encoded as rank 3 with a degenerate
 * outermost dimension of size 1. Kernel-side loads therefore use
 * tma_load_3d(..., coord_x = column, coord_y = row, coord_z = 0).
 *
 * The descriptor is written to `out` (host memory). It must be made
 * available to the kernel (e.g. copied to device memory) before launch.
 *
 * NOTE(review): cuTensorMapEncodeTiled requires, among other things, that
 * the byte row stride (cols * 2) be a multiple of 16 and that `gmem_ptr`
 * be 16-byte aligned. Violations are reported as an encode failure below.
 *
 * @param out       receives the encoded descriptor
 * @param gmem_ptr  device pointer to the BF16 tensor
 * @param rows      number of rows (outer dimension)
 * @param cols      number of columns (innermost, contiguous dimension)
 * @param tile_rows TMA box extent along rows
 * @param tile_cols TMA box extent along columns
 * @param swizzle   SMEM swizzle mode applied to each box (default: none)
 * @return true on success. On failure, the encode arguments are logged to
 *         stderr and false is returned.
 */
inline bool create_tma_desc_2d_bf16(
    CUtensorMap* out,
    const void* gmem_ptr,   // device pointer to the BF16 tensor
    uint64_t rows,          // number of rows (outer dimension)
    uint64_t cols,          // number of columns (innermost dimension)
    uint32_t tile_rows,     // TMA box extent along rows
    uint32_t tile_cols,     // TMA box extent along columns
    CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE
) {
    constexpr uint64_t kBytesPerElem = 2;  // BF16

    // All shape arrays are listed innermost-first: (cols, rows, 1).
    uint64_t global_dim[] = {cols, rows, 1};

    // cuTensorMapEncodeTiled takes (rank - 1) strides, in BYTES, for
    // dims 1..rank-1. Dim 0 is implicitly contiguous.
    //   global_str[0] = byte distance between consecutive rows
    //   global_str[1] = byte stride of the degenerate 3rd dim (whole tensor)
    uint64_t global_str[] = {cols * kBytesPerElem, cols * kBytesPerElem * rows};

    // Box extents, innermost-first: (tile_cols, tile_rows, 1).
    uint32_t tile_dim[] = {tile_cols, tile_rows, 1};

    // Element strides inside the box: dense (1) in every dimension.
    uint32_t tile_str[] = {1, 1, 1};

    CUresult res = cuTensorMapEncodeTiled(
        out,
        CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
        3,  // rank 3: (cols, rows, 1) -- degenerate outer dimension
        const_cast<void*>(gmem_ptr),
        global_dim, global_str, tile_dim, tile_str,
        CU_TENSOR_MAP_INTERLEAVE_NONE,
        swizzle,
        CU_TENSOR_MAP_L2_PROMOTION_NONE,
        CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
    );

    if (res != CUDA_SUCCESS) {
        // uint64_t is `unsigned long` on LP64 Linux but `unsigned long long`
        // on Windows, so %lu is not portable. Print 64-bit values through
        // explicit casts to unsigned long long with %llu instead.
        fprintf(stderr,
                "cuTensorMapEncodeTiled failed: error=%d, gdim=[%llu,%llu,%llu], gstr=[%llu,%llu], tdim=[%u,%u,%u], tstr=[%u,%u,%u]\n",
                (int)res,
                (unsigned long long)global_dim[0],
                (unsigned long long)global_dim[1],
                (unsigned long long)global_dim[2],
                (unsigned long long)global_str[0],
                (unsigned long long)global_str[1],
                (unsigned)tile_dim[0], (unsigned)tile_dim[1], (unsigned)tile_dim[2],
                (unsigned)tile_str[0], (unsigned)tile_str[1], (unsigned)tile_str[2]);
    }

    return res == CUDA_SUCCESS;
}
|
||
|
||
// ==================================================================
|
||
// TMA kernel-side operations
|
||
// ==================================================================
|
||
|
||
/**
 * Initialize an mbarrier in SMEM to track one phase of TMA transfers
 * totalling `expected_bytes` bytes. Only one thread should call this.
 *
 * The operand of mbarrier.init is the expected *arrival* count, not a
 * byte count. Byte tracking for complete_tx::bytes copies is a separate
 * transaction count, set with mbarrier.arrive.expect_tx. Passing the byte
 * count to mbarrier.init makes the phase wait for that many thread
 * arrivals, which never happen, so any wait on the barrier hangs.
 *
 * This helper therefore:
 *   1. initializes the barrier with an arrival count of 1;
 *   2. fences so the async proxy (the TMA unit) observes the initialized
 *      barrier before any cp.async.bulk.tensor signals it;
 *   3. performs that single arrival from the calling thread, while adding
 *      `expected_bytes` to the barrier's pending transaction count.
 * The first phase (parity 0) then completes once the TMA copies that
 * signal this barrier have delivered exactly `expected_bytes` bytes.
 *
 * Requirements / caveats:
 *  - `expected_bytes` must equal the total size of every TMA box that
 *    signals this barrier in this phase, and must be at most 2^20 - 1
 *    (the PTX transaction-count range).
 *  - If threads other than the caller wait on the barrier, synchronize the
 *    CTA (e.g. __syncthreads()) after this call so they observe the init.
 *  - Do not issue another mbarrier.arrive / arrive.expect_tx for this
 *    phase: the single expected arrival is consumed here.
 *
 * @param smem_mbar      shared-window address of the 8-byte mbarrier
 *                       (e.g. from __cvta_generic_to_shared)
 * @param expected_bytes bytes the TMA copies must deliver to complete the phase
 */
__device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar, uint32_t expected_bytes) {
    // Exactly one arrival (performed below) is expected per phase.
    asm volatile("mbarrier.init.shared.b64 [%0], %1;"
                 :: "r"(smem_mbar), "r"(1u)
                 : "memory");

    // Make the freshly initialized barrier visible to the async proxy.
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");

    // Consume the single expected arrival and register the number of bytes
    // the TMA copies will deliver (complete_tx::bytes counts them down).
    asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;"
                 :: "r"(smem_mbar), "r"(expected_bytes)
                 : "memory");
}
|
||
|
||
/**
 * Start one asynchronous 3D TMA tensor copy from GMEM into SMEM.
 *
 * Completion is reported on `smem_mbar` via complete_tx::bytes. Wait on
 * that barrier (tma_mbarrier_wait) before reading the destination.
 * Issue from a single thread per CTA.
 *
 * @param smem_dst  SMEM destination address (via __cvta_generic_to_shared)
 * @param tma_desc  address of the CUtensorMap as a 64-bit integer; the
 *                  descriptor must live in device memory (GMEM)
 * @param smem_mbar SMEM mbarrier address (via __cvta_generic_to_shared)
 * @param coord_x   innermost (column) coordinate of the box origin
 * @param coord_y   middle (row) coordinate of the box origin
 * @param coord_z   outermost coordinate; 0 for the degenerate 3rd dim
 */
__device__ __forceinline__ void tma_load_3d(
    uint32_t smem_dst,
    uint64_t tma_desc,
    uint32_t smem_mbar,
    int coord_x,
    int coord_y,
    int coord_z = 0
) {
    // Operand map:
    //   %0 destination, %1 tensor map, %2..%4 box coordinates (x, y, z),
    //   %5 mbarrier.
    asm volatile(
        "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes "
        "[%0], [%1, {%2, %3, %4}], [%5];"
        :
        : "r"(smem_dst),
          "l"(tma_desc),
          "r"(coord_x), "r"(coord_y), "r"(coord_z),
          "r"(smem_mbar)
        : "memory");
}
|
||
|
||
/**
 * Spin until the mbarrier phase with parity `phase` has completed.
 *
 * Typically only the thread that issued the TMA copy waits, but any number
 * of threads may wait on the same barrier.
 *
 * `phase` selects which phase to wait for. It defaults to 0, the first
 * phase after mbarrier.init, which preserves the previous behaviour.
 * A barrier that is reused without re-initialization (e.g. in a
 * multi-stage pipeline) must be waited on with alternating parity
 * 0, 1, 0, ... . Waiting again on parity 0 after that phase has already
 * completed returns immediately, without waiting for the new data.
 *
 * @param smem_mbar SMEM mbarrier address (via __cvta_generic_to_shared)
 * @param phase     parity of the phase to wait for; only bit 0 is used
 */
__device__ __forceinline__ void tma_mbarrier_wait(uint32_t smem_mbar, uint32_t phase = 0) {
    const uint32_t parity = phase & 1u;
    asm volatile(
        "{\n\t"
        ".reg .pred p;\n\t"
        "LOOP:\n\t"
        "mbarrier.try_wait.parity.shared.b64 p, [%0], %1;\n\t"
        "@p bra DONE;\n\t"
        "bra LOOP;\n\t"
        "DONE:\n\t"
        "}"
        :: "r"(smem_mbar), "r"(parity)
        : "memory"
    );
}
|
||
|
||
/**
 * Emit `cp.async.commit_group`. This commits the calling thread's
 * outstanding non-bulk cp.async operations into a cp.async-group, which can
 * later be waited on with cp.async.wait_group.
 *
 * NOTE(review): despite its name, this does NOT invalidate L2 or any
 * prefetch. Per the PTX ISA, it also has no effect on the mbarrier-completed
 * cp.async.bulk.tensor copies issued by tma_load_3d; those complete through
 * their mbarrier. Only call it if the kernel also issues classic
 * (non-bulk) cp.async copies.
 */
__device__ __forceinline__ void tma_cp_commit() {
    asm volatile("cp.async.commit_group;" ::: "memory");
}
|
||
|
||
// ==================================================================
|
||
// TMA parameter structure
|
||
// ==================================================================
|
||
|
||
// Bundle of TMA descriptor pointers for the attention operands.
// NOTE(review): presumably each points to a CUtensorMap built with
// create_tma_desc_2d_bf16 and copied to device memory -- confirm at call sites.
struct FmhaTmaDescriptors {
    CUtensorMap* __restrict__ tma_q;  // Q descriptor: (T, HD) row-major
    CUtensorMap* __restrict__ tma_k;  // K descriptor: (s_k, HD) row-major
    CUtensorMap* __restrict__ tma_v;  // V descriptor: (HD, s_k) row-major, i.e. V stored transposed
};
|
||
|
||
} // namespace dsv4::kernels::attention
|