TMA 5D test: element stride decomposition

This commit is contained in:
2026-05-28 19:18:01 +00:00
parent 96f2f0bb90
commit 6af2feb42a
2 changed files with 144 additions and 4 deletions

View File

@@ -92,10 +92,12 @@ class FmhaKernel:
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
# PV N=16 sub-tiles: avoid tcgen05.mma Layout D bug where N=64
# skips TMEM columns 32-35 and 48-51. N=16 works for all HD values.
# More PV calls per K-tile, but each is correct.
self.pv_n_tile = 16
self.pv_n_tile = min(head_dim, 256)
# At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB,
# making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512
# (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256.
if head_dim > 256:
self.pv_n_tile = 128
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.num_query_heads = num_query_heads

138
tests/unit/test_tma_5d.cu Normal file
View File

@@ -0,0 +1,138 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdio>
typedef unsigned short bf16_t;
int main() {
printf("=== TMA 5D + Interleave test ===\n");
bf16_t* d_data;
cudaMalloc(&d_data, 128*16*2 + 256);
uint64_t addr = ((uint64_t)d_data + 255) & ~255ULL;
bf16_t* d_aligned = (bf16_t*)addr;
printf("ptr: %p (256B aligned)\n", d_aligned);
CUtensorMap tma;
CUresult r;
// Approach: 5D descriptor with padded dimensions
// V is (hd=16, s_k=128) in row-major → 16 rows of 128 cols
// As a TMA tensor: innermost dim is the contiguous dimension
// For row-major: dim[0] = cols (innermost), dim[1] = rows
// globalDim: [16, 128, 1, 1, 1] → 5D
// globalStrides (bytes, tensorRank-1=4): [2, 256, 256*128, 256*128, 256*128]
// But element strides must be <= 8:
// elementStride[0] = 1 (cols), elementStride[1] = 16 (row stride in elements)
// 16 > 8 → INVALID
// With INTERLEAVE_16B: each element is treated as 16-byte blocks
// This changes the stride calculation
// elementStride with interleave: stride is in 16-byte blocks, not elements
// [1, 1] → 1 interleave-block stride in x, 1 interleave-block stride in y
// Try: 5D with INTERLEAVE_16B
{
// For (128, 16) BF16 matrix in row-major:
// 128 rows × 16 cols × 2 bytes = 4096 bytes
// With INTERLEAVE_16B, we load 16-byte (8 BF16) chunks
// Tile: (8 cols, 128 rows) in 16-byte chunks = 128 chunks per tile
// Actually, let me just try different interleave modes
// 5D global dim: [16, 128, 1, 1, 1]
// 5D global stride (bytes, 4 elements): [2, 256, 256*128, 256*128, 256*128]
// 5D tile dim: [16, 128, 1, 1, 1]
// 5D tile stride: [1, 16, 16*128, 16*128, 16*128]
uint64_t gdim[] = {16, 128, 1, 1, 1};
uint64_t gstr[] = {2, 256, 32768, 32768, 32768};
uint32_t tdim[] = {16, 128, 1, 1, 1};
uint32_t tstr[] = {1, 16, 2048, 2048, 2048};
r = cuTensorMapEncodeTiled(&tma, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 5, d_aligned,
gdim, gstr, tdim, tstr,
CU_TENSOR_MAP_INTERLEAVE_NONE,
CU_TENSOR_MAP_SWIZZLE_NONE,
CU_TENSOR_MAP_L2_PROMOTION_NONE,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
printf("5D NONE: result=%d\n", r);
}
// Try with SWIZZLE_128B
{
uint64_t gdim[] = {16, 128, 1, 1, 1};
uint64_t gstr[] = {2, 256, 32768, 32768, 32768};
uint32_t tdim[] = {16, 128, 1, 1, 1};
uint32_t tstr[] = {1, 16, 2048, 2048, 2048};
r = cuTensorMapEncodeTiled(&tma, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 5, d_aligned,
gdim, gstr, tdim, tstr,
CU_TENSOR_MAP_INTERLEAVE_NONE,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_NONE,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
printf("5D SWIZZLE_128B: result=%d\n", r);
}
// Try with smaller tile dims that have strides <= 8
// Tile (8, 8) with stride [1, 16] → stride[1] = 16 > 8 STILL FAILS
// Tile (16, 1) with stride [1, 16] → stride[1] = 16 > 8
// Actually, let me check: is elementStrides in ELEMENTS or BYTES?
// If in bytes: stride[1] = 16 elements * 2 bytes = 32 → still > 8
// I think the element strides constraint is the FUNDAMENTAL issue
// For row-major (128, 16) BF16, row stride = 16 elements > 8
// There is NO way to make this work with 2D TMA for this layout
// The CUTLASS solution: they use K-major (column-major) tensors
// In K-major, the innermost dimension is K (contiguous), stride = 1
// This works because the MMA also reads K-major data
// For V in K-major: (K=128, N=16) → stride between N elements = 128 elements
// elementStrides = [1, 128] → 128 >> 8 → STILL FAILS!
// But wait: CUTLASS decomposes 128 = 8 * 16
// In 5D: (8, 16, 1, 1, 16) with strides [1, 8, 128, 128, 128]
// elementStrides = [1, 8, 16, 16, 16] → ALL ≤ 16 but not ≤ 8
// Hmm. Let me try with strides that are all <= 8
// 128 = 2^7 = 8 * 16 = 4 * 32 = ...
// (8, 16, 1, 1, 16) → strides [1, 8, 128, 128, 128]
// elementStrides = [1, 8, 16, 16, 16]
// 8 is the max, and stride[1] = 8 is OK, but stride[2] = 16 > 8
// Need: (8, 8, 2, 1, 16) → strides [1, 8, 64, 128, 128]
// elementStrides = [1, 8, 8, 16, 16] → stride[3] = 16 > 8
// (8, 8, 8, 1, 2) → strides [1, 8, 64, 512, 512]
// elementStrides = [1, 8, 8, 8, 1] → ALL <= 8!
// This would mean: dim = [8, 8, 8, 1, 2], total = 8*8*8*1*2 = 1024 elements
// But our tensor is 128*16 = 2048 elements (2x too big)
// (8, 8, 8, 1, 4) → total = 2048, strides = [1, 8, 64, 512, 2048]
// elementStrides = [1, 8, 8, 8, 4] → ALL <= 8!
// This works!
{
// 5D descriptor for (128, 16) BF16 in K-major (col-major)
// Decomposition: (8, 8, 8, 1, 4)
// Total elements: 8*8*8*1*4 = 2048 = 128*16 ✓
// elementStrides: [1, 8, 8, 8, 4] → all ≤ 8 ✓
// globalStrides (bytes): [2, 16, 128, 1024, 1024]
uint64_t gdim[] = {8, 8, 8, 1, 4};
uint64_t gstr[] = {2, 16, 128, 1024, 1024};
uint32_t tdim[] = {8, 8, 8, 1, 4};
uint32_t tstr[] = {1, 8, 64, 512, 512};
r = cuTensorMapEncodeTiled(&tma, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 5, d_aligned,
gdim, gstr, tdim, tstr,
CU_TENSOR_MAP_INTERLEAVE_NONE,
CU_TENSOR_MAP_SWIZZLE_NONE,
CU_TENSOR_MAP_L2_PROMOTION_NONE,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
printf("5D (8,8,8,1,4) strides [1,8,8,8,4]: result=%d\n", r);
}
cudaFree(d_data);
return 0;
}