Files
nvfp4-megamoe-kernel/archived_plans/code_archive/fmha_tma_driver_api.cuh
biondizzle 4b9eed02e1 Cleanup C1-C7: delete dead CuTeDSL FMHA, test probes, scratch files
- Deleted fmha.py (CuTeDSL slow path), FmhaKernel, Python KV merge
- Deleted fmha_sm100.cuh, fmha_sm100_tc.cuh, fmha_sm100_launch.cu, fmha_epilogue_sm100.cuh
- Moved fmha_qk_verify.cuh to tests/unit/qk_verify_kernel.cuh
- Deleted decode_sparse.py, decode_swa.py, kernels/decode/
- Deleted 46 test_d*.py probes, test_smem_*, test_cotiled_*, test_tmem_*,
  test_smem_p_*, test_ultra_minimal, test_fmha_pv16, test_working_softmax_maybe
- Deleted root scratch: debug_linear.py, test_mapping.py, run_router_tests.py
- Moved archive/ to archived_plans/code_archive/
- Rewrote production.py: single fast path via 6-warp multi-tile kernel
- Added STATUS.md, audit_attention_live.md
- Moved NEXT_PRIORITIES*.md to archived_plans/
2026-05-30 21:08:12 +00:00

258 lines
9.9 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* DSV4 FMHA — TMA async load infrastructure for Blackwell SM100.
*
* ==================================================================
* DESIGN
* ==================================================================
*
* Replaces scalar GMEM reads in the load warp with async TMA bulk
* copies via cp.async.bulk.tensor.3d. The pipeline:
*
* Host: CUtensorMap creation for Q, K, V tiles
* Kernel:
* 1. TMA warp issues cp.async.bulk.tensor.3d → SMEM (row-major)
* 2. mbarrier wait for TMA completion
* 3. Load warp transposes row-major SMEM → canonical K-major SMEM
* 4. MMA warp runs tcgen05.mma as before
*
* For double-buffered pipeline overlap (future):
* - Two SMEM buffers per tensor (sQ0/sQ1, sK0/sK1)
* - TMA load of K-tile (kt+1) overlaps with MMA on K-tile (kt)
* - Pipeline stages managed via mbarrier arrive/wait
*
* ==================================================================
* TMA DESCRIPTOR LAYOUT
* ==================================================================
*
* We create one CUtensorMap descriptor for each 2D tensor the kernel needs
* (each is encoded as rank 3 with a degenerate outermost dimension of 1):
*
* Q tile: (T, HD) — one tile for the full Q
* K tile: (s_k, HD) — one tile for the full K (or (128, 16) per K-sub-tile)
* V tile: (HD, s_k) — transposed, one tile for the full V
*
* TMA copies data from GMEM to SMEM in row-major order. After TMA
* completion, the load warp transposes from row-major to the
* canonical K-major core-matrix layout that tcgen05.mma expects.
*
* For the multirow kernel, Q is (T, HD) and K is (s_k, HD).
* Since TMA operates on 2D tiles and our SMEM is (128, 16) per
* MMA K-tile, we have two choices:
*
* Option A: TMA load full (T, HD) → row-major SMEM → transpose
* - One TMA descriptor for Q, one for K
* - Larger SMEM footprint (need row-major + canonical)
* - Simpler descriptor management
*
* Option B: TMA load per (128, 16) K-sub-tile
* - One TMA descriptor, multiple TMA issues with different coords
* - Same SMEM as current (no double buffer needed for single-stage)
* - Matches the existing K-tiling loop structure
*
* We choose Option B: TMA per (128, 16) K-sub-tile. This:
* - Reuses the exact same SMEM layout as the current kernel
* - Fits the existing QK loop structure (kt = 0..NKT_QK-1)
* - Enables future pipeline overlap with minimal changes
* - The TMA descriptor covers the full (T, HD) or (s_k, HD) tensor,
* and we issue TMA loads for specific (col, row) coordinates
* targeting each 128×16 tile
*
* ==================================================================
* MBARRIER PROTOCOL
* ==================================================================
*
* TMA async copies use mbarrier for completion signaling:
*
* 1. Init mbarrier with arrival count = 1, then post the expected
* transaction byte count for the copy (mbarrier.arrive.expect_tx)
* 2. Issue cp.async.bulk.tensor.3d with the mbarrier
* 3. Wait on mbarrier parity (spin or yield)
* 4. After wait returns, SMEM data is ready
*
* The mbarrier lives in SMEM. One mbarrier per outstanding TMA
* operation. For single-stage (no overlap), we use one mbarrier
* and wait immediately after issue.
*
* ==================================================================
* SWIZZLE CONSIDERATIONS
* ==================================================================
*
* TMA descriptors support SWIZZLE_NONE, SWIZZLE_32B, SWIZZLE_64B,
* SWIZZLE_128B. The swizzle pattern in SMEM matches what UMMA
* descriptors expect when using make_umma_desc_kmajor_sw128.
*
* Current kernel uses SWIZZLE_NONE (make_umma_desc_kmajor_none).
* With TMA, we have two paths:
*
* Path 1: TMA with SWIZZLE_NONE → SMEM is row-major → transpose to canonical
* Path 2: TMA with SWIZZLE_128B → SMEM is swizzled → UMMA reads directly
*
* Path 2 is the production target: no transpose needed, TMA writes
* in the exact layout MMA reads. But getting the swizzle right is
* tricky and needs careful verification.
*
* We start with Path 1 (SWIZZLE_NONE + transpose) to get TMA working,
* then upgrade to Path 2 (SWIZZLE_128B, zero-copy) for performance.
* ==================================================================
*/
#pragma once
#include "fmha_common.cuh"
#include <cstdint>
namespace dsv4::kernels::attention {
// ==================================================================
// TMA descriptor helpers (host-side)
// ==================================================================
// These are called from host code to create CUtensorMap objects
// that the kernel uses for TMA async copies.
// ==================================================================
/**
* Build a TMA descriptor for a row-major BF16 matrix of shape (rows, cols).
*
* Although the tensor is logically 2D, the descriptor is encoded as rank 3
* with a degenerate outermost dimension of 1, so kernels must address it
* with a 3D TMA copy (see tma_load_3d; the third coordinate is 0).
*
* The descriptor is written to `out` (host memory). The kernel-side helper
* in this file takes the descriptor by device address, so it must be copied
* to device memory before kernel launch.
*
* @param out        Receives the encoded descriptor.
* @param gmem_ptr   Device pointer to the first BF16 element of the tensor.
* @param rows       Number of rows (global dimension 1, middle).
* @param cols       Number of columns (global dimension 0, innermost);
*                   also the row pitch in elements.
* @param tile_rows  Rows per TMA tile (box).
* @param tile_cols  Columns per TMA tile (box).
* @param swizzle    SMEM swizzle mode; defaults to no swizzle.
* @return true on success; on failure the driver error code and every
*         argument passed to the encoder are printed to stderr and false
*         is returned.
*/
inline bool create_tma_desc_2d_bf16(
    CUtensorMap* out,
    const void* gmem_ptr,  // device pointer to the BF16 tensor
    uint64_t rows,         // global dimension 1 (number of rows)
    uint64_t cols,         // global dimension 0 (number of columns)
    uint32_t tile_rows,    // TMA tile rows
    uint32_t tile_cols,    // TMA tile columns
    CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE
) {
    // Global strides are expressed in BYTES, not elements, and only
    // rank-1 of them are passed (the innermost stride is implicit).
    constexpr uint64_t kBf16Bytes = 2;
    const uint64_t row_pitch_bytes = cols * kBf16Bytes;

    // All arrays are innermost-first: (cols, rows, 1).
    uint64_t global_dim[] = {cols, rows, 1};
    uint64_t global_str[] = {row_pitch_bytes, row_pitch_bytes * rows};
    uint32_t tile_dim[] = {tile_cols, tile_rows, 1};
    // Element stride of 1 in every dimension: the tile is dense.
    uint32_t tile_str[] = {1, 1, 1};

    const CUresult res = cuTensorMapEncodeTiled(
        out,
        CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
        3,  // rank 3 with a degenerate third dimension
        const_cast<void*>(gmem_ptr),
        global_dim, global_str, tile_dim, tile_str,
        CU_TENSOR_MAP_INTERLEAVE_NONE,
        swizzle,
        CU_TENSOR_MAP_L2_PROMOTION_NONE,
        CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
    );
    if (res != CUDA_SUCCESS) {
        // uint64_t is not `unsigned long` on every platform (e.g. LLP64),
        // so widen explicitly and print with %llu.
        fprintf(stderr, "cuTensorMapEncodeTiled failed: error=%d, gdim=[%llu,%llu,%llu], gstr=[%llu,%llu], tdim=[%u,%u,%u], tstr=[%u,%u,%u]\n",
                (int)res,
                (unsigned long long)global_dim[0], (unsigned long long)global_dim[1],
                (unsigned long long)global_dim[2],
                (unsigned long long)global_str[0], (unsigned long long)global_str[1],
                tile_dim[0], tile_dim[1], tile_dim[2],
                tile_str[0], tile_str[1], tile_str[2]);
    }
    return res == CUDA_SUCCESS;
}
// ==================================================================
// TMA kernel-side operations
// ==================================================================
/**
* Arm an mbarrier in SMEM for one TMA transfer of `expected_bytes` bytes.
* Only one thread should call this.
*
* `mbarrier.init` takes the expected ARRIVAL count, not a byte count, so
* the barrier is initialized with a single expected arrival. The byte count
* is then posted with `mbarrier.arrive.expect_tx`, which also performs that
* one arrival. The phase therefore completes exactly when a copy issued
* with `mbarrier::complete_tx::bytes` (see tma_load_3d) has delivered
* `expected_bytes` bytes, at which point tma_mbarrier_wait returns.
*
* Requires the mbarrier transaction-count instructions (SM90+).
*
* @param smem_mbar       SMEM mbarrier address (via __cvta_generic_to_shared)
* @param expected_bytes  Total bytes the TMA copy will write for this phase
*/
__device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar, uint32_t expected_bytes) {
    asm volatile(
        // One expected arrival: the expect_tx arrive issued just below.
        "mbarrier.init.shared.b64 [%0], 1;\n\t"
        // Post the transaction byte count and perform the single arrival.
        "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n\t"
        // Make the initialized barrier visible to the async (TMA) proxy
        // before any copy that signals it is issued.
        "fence.proxy.async.shared::cta;"
        :: "r"(smem_mbar), "r"(expected_bytes)
        : "memory");
}
/**
* Issue one 3D TMA tensor copy from GMEM into SMEM.
*
* Emits a single `cp.async.bulk.tensor.3d` instruction whose destination is
* in the `shared::cluster` state space and whose completion is reported to
* an mbarrier in bytes (`mbarrier::complete_tx::bytes`). The call returns as
* soon as the copy is issued; the data lands in SMEM asynchronously.
* Use tma_mbarrier_wait to wait for completion.
*
* The TMA descriptor must be in device memory (GMEM).
* Only ONE thread per CTA should issue the TMA copy.
*
* NOTE(review): per the PTX ISA this instruction needs SM90 or newer; the
* file header targets SM100.
*
* @param smem_dst SMEM destination address (via __cvta_generic_to_shared)
* @param tma_desc Pointer to CUtensorMap in device memory (uint64_t cast)
* @param smem_mbar SMEM mbarrier address (via __cvta_generic_to_shared)
* @param coord_x Innermost dimension coordinate (columns)
* @param coord_y Middle dimension coordinate (rows)
* @param coord_z Outermost dimension coordinate (degenerate = 0)
*/
__device__ __forceinline__ void tma_load_3d(
uint32_t smem_dst,
uint64_t tma_desc,
uint32_t smem_mbar,
int coord_x,
int coord_y,
int coord_z = 0
) {
// Operand map: %0 = SMEM destination, %1 = descriptor address,
// %2 = mbarrier, %3..%5 = tensor coordinates (innermost first).
// The "memory" clobber keeps the compiler from reordering memory
// accesses across the issue point.
asm volatile(
"cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes "
"[%0], [%1, {%3, %4, %5}], [%2];"
:: "r"(smem_dst),
"l"(tma_desc),
"r"(smem_mbar),
"r"(coord_x),
"r"(coord_y),
"r"(coord_z)
: "memory"
);
}
/**
* Block until the given phase of an mbarrier has completed.
*
* Loops on `mbarrier.try_wait.parity` until it reports that the phase with
* the requested parity is complete. Only ONE thread needs to wait (or all
* threads may, but typically just the thread that issued the TMA copy).
*
* A freshly initialized barrier starts in phase 0, so the default matches
* the single-shot "init, issue, wait" sequence. When one barrier is reused
* for several transfers, the parity flips after every completed phase:
* pass the running iteration index (or the explicit 0/1 parity). Waiting
* on a stale parity returns immediately, before the new data has landed.
*
* @param smem_mbar SMEM mbarrier address (via __cvta_generic_to_shared)
* @param phase     Phase to wait for; only the lowest bit (the parity) is
*                  used. Defaults to 0, the first phase after init.
*/
__device__ __forceinline__ void tma_mbarrier_wait(uint32_t smem_mbar, uint32_t phase = 0) {
    // try_wait.parity only accepts 0 or 1; masking lets callers pass a
    // monotonically increasing iteration counter directly.
    const uint32_t parity = phase & 1u;
    asm volatile(
        "{\n\t"
        ".reg .pred p;\n\t"
        "LOOP:\n\t"
        "mbarrier.try_wait.parity.shared.b64 p, [%0], %1;\n\t"
        "@p bra DONE;\n\t"
        "bra LOOP;\n\t"
        "DONE:\n\t"
        "}"
        :: "r"(smem_mbar), "r"(parity)
        : "memory"
    );
}
/**
* Issue `cp.async.commit_group` (with a compiler memory clobber).
*
* NOTE(review): this helper was previously described as "invalidate L2
* prefetch to ensure TMA sees fresh data", which does not match the
* instruction emitted here. Per the PTX ISA, `cp.async.commit_group` closes
* the calling thread's outstanding (non-bulk) `cp.async` operations into a
* group for a later `cp.async.wait_group`; it performs no cache
* invalidation. The TMA copy in this file reports completion through an
* mbarrier rather than through commit groups, so confirm whether any caller
* still needs this call.
*/
__device__ __forceinline__ void tma_cp_commit() {
asm volatile("cp.async.commit_group;" ::: "memory");
}
// ==================================================================
// TMA parameter structure
// ==================================================================
// Bundle of pointers to the three tensor-map descriptors (Q, K, V) used by
// the FMHA kernel. The struct holds pointers only; it does not own or
// allocate the descriptors.
// NOTE(review): presumably these are device addresses, since tma_load_3d
// requires the descriptor to be in device memory — confirm against the
// launch code.
struct FmhaTmaDescriptors {
CUtensorMap* __restrict__ tma_q; // Q descriptor: (T, HD) row-major
CUtensorMap* __restrict__ tma_k; // K descriptor: (s_k, HD) row-major
CUtensorMap* __restrict__ tma_v; // V descriptor: (HD, s_k) row-major
};
} // namespace dsv4::kernels::attention