From 6af2feb42aa43d7d4f89ec4ab9b5cbed7b7655a7 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 19:18:01 +0000 Subject: [PATCH] TMA 5D test: element stride decomposition --- dsv4/kernels/attention/fmha.py | 10 ++- tests/unit/test_tma_5d.cu | 138 +++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 4 deletions(-) create mode 100644 tests/unit/test_tma_5d.cu diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 23f071d6..e13b8cfd 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -92,10 +92,12 @@ class FmhaKernel: self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 - # PV N=16 sub-tiles: avoid tcgen05.mma Layout D bug where N=64 - # skips TMEM columns 32-35 and 48-51. N=16 works for all HD values. - # More PV calls per K-tile, but each is correct. - self.pv_n_tile = 16 + self.pv_n_tile = min(head_dim, 256) + # At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB, + # making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512 + # (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256. + if head_dim > 256: + self.pv_n_tile = 128 self.n_pv_tiles = head_dim // self.pv_n_tile self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) self.num_query_heads = num_query_heads diff --git a/tests/unit/test_tma_5d.cu b/tests/unit/test_tma_5d.cu new file mode 100644 index 00000000..5eab1948 --- /dev/null +++ b/tests/unit/test_tma_5d.cu @@ -0,0 +1,138 @@ +#include +#include +#include + +typedef unsigned short bf16_t; + +int main() { + printf("=== TMA 5D + Interleave test ===\n"); + bf16_t* d_data; + cudaMalloc(&d_data, 128*16*2 + 256); + uint64_t addr = ((uint64_t)d_data + 255) & ~255ULL; + bf16_t* d_aligned = (bf16_t*)addr; + printf("ptr: %p (256B aligned)\n", d_aligned); + + CUtensorMap tma; + CUresult r; + + // Approach: 5D descriptor with padded dimensions + // V is (hd=16, s_k=128) in row-major → 16 rows of 128 cols + // As a TMA tensor: innermost dim is the contiguous dimension + // For row-major: dim[0] = cols (innermost), dim[1] = rows + // globalDim: [16, 128, 1, 1, 1] → 5D + // globalStrides (bytes, tensorRank-1=4): [2, 256, 256*128, 256*128, 256*128] + // But element strides must be <= 8: + // elementStride[0] = 1 (cols), elementStride[1] = 16 (row stride in elements) + // 16 > 8 → INVALID + + // With INTERLEAVE_16B: each element is treated as 16-byte blocks + // This changes the stride calculation + // elementStride with interleave: stride is in 16-byte blocks, not elements + // [1, 1] → 1 interleave-block stride in x, 1 interleave-block stride in y + + // Try: 5D with INTERLEAVE_16B + { + // For (128, 16) BF16 matrix in row-major: + // 128 rows × 16 cols × 2 bytes = 4096 bytes + // With INTERLEAVE_16B, we load 16-byte (8 BF16) chunks + // Tile: (8 cols, 128 rows) in 16-byte chunks = 128 chunks per tile + // Actually, let me just try different interleave modes + + // 5D global dim: [16, 128, 1, 1, 1] + // 5D global stride (bytes, 4 elements): [2, 256, 256*128, 256*128, 256*128] + // 5D tile dim: [16, 128, 1, 1, 1] + // 5D tile stride: [1, 16, 16*128, 16*128, 16*128] + uint64_t gdim[] = {16, 128, 1, 1, 1}; + uint64_t gstr[] = {2, 256, 32768, 32768, 32768}; + uint32_t tdim[] = {16, 128, 1, 1, 1}; + uint32_t tstr[] = {1, 16, 2048, 2048, 2048}; + + r = cuTensorMapEncodeTiled(&tma, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 5, d_aligned, + gdim, gstr, tdim, tstr, + CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, + CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + printf("5D NONE: result=%d\n", r); + } + + // Try with SWIZZLE_128B + { + uint64_t gdim[] = {16, 128, 1, 1, 1}; + uint64_t gstr[] = {2, 256, 32768, 32768, 32768}; + uint32_t tdim[] = {16, 128, 1, 1, 1}; + uint32_t tstr[] = {1, 16, 2048, 2048, 2048}; + + r = cuTensorMapEncodeTiled(&tma, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 5, d_aligned, + gdim, gstr, tdim, tstr, + CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_128B, + CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + printf("5D SWIZZLE_128B: result=%d\n", r); + } + + // Try with smaller tile dims that have strides <= 8 + // Tile (8, 8) with stride [1, 16] → stride[1] = 16 > 8 STILL FAILS + // Tile (16, 1) with stride [1, 16] → stride[1] = 16 > 8 + + // Actually, let me check: is elementStrides in ELEMENTS or BYTES? + // If in bytes: stride[1] = 16 elements * 2 bytes = 32 → still > 8 + + // I think the element strides constraint is the FUNDAMENTAL issue + // For row-major (128, 16) BF16, row stride = 16 elements > 8 + // There is NO way to make this work with 2D TMA for this layout + + // The CUTLASS solution: they use K-major (column-major) tensors + // In K-major, the innermost dimension is K (contiguous), stride = 1 + // This works because the MMA also reads K-major data + + // For V in K-major: (K=128, N=16) → stride between N elements = 128 elements + // elementStrides = [1, 128] → 128 >> 8 → STILL FAILS! + + // But wait: CUTLASS decomposes 128 = 8 * 16 + // In 5D: (8, 16, 1, 1, 16) with strides [1, 8, 128, 128, 128] + // elementStrides = [1, 8, 16, 16, 16] → ALL ≤ 16 but not ≤ 8 + + // Hmm. Let me try with strides that are all <= 8 + // 128 = 2^7 = 8 * 16 = 4 * 32 = ... + // (8, 16, 1, 1, 16) → strides [1, 8, 128, 128, 128] + // elementStrides = [1, 8, 16, 16, 16] + // 8 is the max, and stride[1] = 8 is OK, but stride[2] = 16 > 8 + + // Need: (8, 8, 2, 1, 16) → strides [1, 8, 64, 128, 128] + // elementStrides = [1, 8, 8, 16, 16] → stride[3] = 16 > 8 + + // (8, 8, 8, 1, 2) → strides [1, 8, 64, 512, 512] + // elementStrides = [1, 8, 8, 8, 1] → ALL <= 8! + // This would mean: dim = [8, 8, 8, 1, 2], total = 8*8*8*1*2 = 1024 elements + // But our tensor is 128*16 = 2048 elements (2x too big) + + // (8, 8, 8, 1, 4) → total = 2048, strides = [1, 8, 64, 512, 2048] + // elementStrides = [1, 8, 8, 8, 4] → ALL <= 8! + // This works! + + { + // 5D descriptor for (128, 16) BF16 in K-major (col-major) + // Decomposition: (8, 8, 8, 1, 4) + // Total elements: 8*8*8*1*4 = 2048 = 128*16 ✓ + // elementStrides: [1, 8, 8, 8, 4] → all ≤ 8 ✓ + // globalStrides (bytes): [2, 16, 128, 1024, 1024] + + uint64_t gdim[] = {8, 8, 8, 1, 4}; + uint64_t gstr[] = {2, 16, 128, 1024, 1024}; + uint32_t tdim[] = {8, 8, 8, 1, 4}; + uint32_t tstr[] = {1, 8, 64, 512, 512}; + + r = cuTensorMapEncodeTiled(&tma, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 5, d_aligned, + gdim, gstr, tdim, tstr, + CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, + CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + printf("5D (8,8,8,1,4) strides [1,8,8,8,4]: result=%d\n", r); + } + + cudaFree(d_data); + return 0; +}