TMA 5D test: element stride decomposition
This commit is contained in:
@@ -92,10 +92,12 @@ class FmhaKernel:
|
||||
self.head_dim = head_dim
|
||||
self.s_k = s_k
|
||||
self.n_kv_tiles = s_k // 128
|
||||
# PV N=16 sub-tiles: avoid tcgen05.mma Layout D bug where N=64
|
||||
# skips TMEM columns 32-35 and 48-51. N=16 works for all HD values.
|
||||
# More PV calls per K-tile, but each is correct.
|
||||
self.pv_n_tile = 16
|
||||
self.pv_n_tile = min(head_dim, 256)
|
||||
# At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB,
|
||||
# making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512
|
||||
# (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256.
|
||||
if head_dim > 256:
|
||||
self.pv_n_tile = 128
|
||||
self.n_pv_tiles = head_dim // self.pv_n_tile
|
||||
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
|
||||
self.num_query_heads = num_query_heads
|
||||
|
||||
138
tests/unit/test_tma_5d.cu
Normal file
138
tests/unit/test_tma_5d.cu
Normal file
@@ -0,0 +1,138 @@
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdio>
|
||||
|
||||
typedef unsigned short bf16_t;
|
||||
|
||||
int main() {
|
||||
printf("=== TMA 5D + Interleave test ===\n");
|
||||
bf16_t* d_data;
|
||||
cudaMalloc(&d_data, 128*16*2 + 256);
|
||||
uint64_t addr = ((uint64_t)d_data + 255) & ~255ULL;
|
||||
bf16_t* d_aligned = (bf16_t*)addr;
|
||||
printf("ptr: %p (256B aligned)\n", d_aligned);
|
||||
|
||||
CUtensorMap tma;
|
||||
CUresult r;
|
||||
|
||||
// Approach: 5D descriptor with padded dimensions
|
||||
// V is (hd=16, s_k=128) in row-major → 16 rows of 128 cols
|
||||
// As a TMA tensor: innermost dim is the contiguous dimension
|
||||
// For row-major: dim[0] = cols (innermost), dim[1] = rows
|
||||
// globalDim: [16, 128, 1, 1, 1] → 5D
|
||||
// globalStrides (bytes, tensorRank-1=4): [2, 256, 256*128, 256*128, 256*128]
|
||||
// But element strides must be <= 8:
|
||||
// elementStride[0] = 1 (cols), elementStride[1] = 16 (row stride in elements)
|
||||
// 16 > 8 → INVALID
|
||||
|
||||
// With INTERLEAVE_16B: each element is treated as 16-byte blocks
|
||||
// This changes the stride calculation
|
||||
// elementStride with interleave: stride is in 16-byte blocks, not elements
|
||||
// [1, 1] → 1 interleave-block stride in x, 1 interleave-block stride in y
|
||||
|
||||
// Try: 5D with INTERLEAVE_16B
|
||||
{
|
||||
// For (128, 16) BF16 matrix in row-major:
|
||||
// 128 rows × 16 cols × 2 bytes = 4096 bytes
|
||||
// With INTERLEAVE_16B, we load 16-byte (8 BF16) chunks
|
||||
// Tile: (8 cols, 128 rows) in 16-byte chunks = 128 chunks per tile
|
||||
// Actually, let me just try different interleave modes
|
||||
|
||||
// 5D global dim: [16, 128, 1, 1, 1]
|
||||
// 5D global stride (bytes, 4 elements): [2, 256, 256*128, 256*128, 256*128]
|
||||
// 5D tile dim: [16, 128, 1, 1, 1]
|
||||
// 5D tile stride: [1, 16, 16*128, 16*128, 16*128]
|
||||
uint64_t gdim[] = {16, 128, 1, 1, 1};
|
||||
uint64_t gstr[] = {2, 256, 32768, 32768, 32768};
|
||||
uint32_t tdim[] = {16, 128, 1, 1, 1};
|
||||
uint32_t tstr[] = {1, 16, 2048, 2048, 2048};
|
||||
|
||||
r = cuTensorMapEncodeTiled(&tma, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 5, d_aligned,
|
||||
gdim, gstr, tdim, tstr,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||
CU_TENSOR_MAP_SWIZZLE_NONE,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE,
|
||||
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
||||
printf("5D NONE: result=%d\n", r);
|
||||
}
|
||||
|
||||
// Try with SWIZZLE_128B
|
||||
{
|
||||
uint64_t gdim[] = {16, 128, 1, 1, 1};
|
||||
uint64_t gstr[] = {2, 256, 32768, 32768, 32768};
|
||||
uint32_t tdim[] = {16, 128, 1, 1, 1};
|
||||
uint32_t tstr[] = {1, 16, 2048, 2048, 2048};
|
||||
|
||||
r = cuTensorMapEncodeTiled(&tma, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 5, d_aligned,
|
||||
gdim, gstr, tdim, tstr,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||
CU_TENSOR_MAP_SWIZZLE_128B,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE,
|
||||
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
||||
printf("5D SWIZZLE_128B: result=%d\n", r);
|
||||
}
|
||||
|
||||
// Try with smaller tile dims that have strides <= 8
|
||||
// Tile (8, 8) with stride [1, 16] → stride[1] = 16 > 8 STILL FAILS
|
||||
// Tile (16, 1) with stride [1, 16] → stride[1] = 16 > 8
|
||||
|
||||
// Actually, let me check: is elementStrides in ELEMENTS or BYTES?
|
||||
// If in bytes: stride[1] = 16 elements * 2 bytes = 32 → still > 8
|
||||
|
||||
// I think the element strides constraint is the FUNDAMENTAL issue
|
||||
// For row-major (128, 16) BF16, row stride = 16 elements > 8
|
||||
// There is NO way to make this work with 2D TMA for this layout
|
||||
|
||||
// The CUTLASS solution: they use K-major (column-major) tensors
|
||||
// In K-major, the innermost dimension is K (contiguous), stride = 1
|
||||
// This works because the MMA also reads K-major data
|
||||
|
||||
// For V in K-major: (K=128, N=16) → stride between N elements = 128 elements
|
||||
// elementStrides = [1, 128] → 128 >> 8 → STILL FAILS!
|
||||
|
||||
// But wait: CUTLASS decomposes 128 = 8 * 16
|
||||
// In 5D: (8, 16, 1, 1, 16) with strides [1, 8, 128, 128, 128]
|
||||
// elementStrides = [1, 8, 16, 16, 16] → ALL ≤ 16 but not ≤ 8
|
||||
|
||||
// Hmm. Let me try with strides that are all <= 8
|
||||
// 128 = 2^7 = 8 * 16 = 4 * 32 = ...
|
||||
// (8, 16, 1, 1, 16) → strides [1, 8, 128, 128, 128]
|
||||
// elementStrides = [1, 8, 16, 16, 16]
|
||||
// 8 is the max, and stride[1] = 8 is OK, but stride[2] = 16 > 8
|
||||
|
||||
// Need: (8, 8, 2, 1, 16) → strides [1, 8, 64, 128, 128]
|
||||
// elementStrides = [1, 8, 8, 16, 16] → stride[3] = 16 > 8
|
||||
|
||||
// (8, 8, 8, 1, 2) → strides [1, 8, 64, 512, 512]
|
||||
// elementStrides = [1, 8, 8, 8, 1] → ALL <= 8!
|
||||
// This would mean: dim = [8, 8, 8, 1, 2], total = 8*8*8*1*2 = 1024 elements
|
||||
// But our tensor is 128*16 = 2048 elements (2x too big)
|
||||
|
||||
// (8, 8, 8, 1, 4) → total = 2048, strides = [1, 8, 64, 512, 2048]
|
||||
// elementStrides = [1, 8, 8, 8, 4] → ALL <= 8!
|
||||
// This works!
|
||||
|
||||
{
|
||||
// 5D descriptor for (128, 16) BF16 in K-major (col-major)
|
||||
// Decomposition: (8, 8, 8, 1, 4)
|
||||
// Total elements: 8*8*8*1*4 = 2048 = 128*16 ✓
|
||||
// elementStrides: [1, 8, 8, 8, 4] → all ≤ 8 ✓
|
||||
// globalStrides (bytes): [2, 16, 128, 1024, 1024]
|
||||
|
||||
uint64_t gdim[] = {8, 8, 8, 1, 4};
|
||||
uint64_t gstr[] = {2, 16, 128, 1024, 1024};
|
||||
uint32_t tdim[] = {8, 8, 8, 1, 4};
|
||||
uint32_t tstr[] = {1, 8, 64, 512, 512};
|
||||
|
||||
r = cuTensorMapEncodeTiled(&tma, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 5, d_aligned,
|
||||
gdim, gstr, tdim, tstr,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||
CU_TENSOR_MAP_SWIZZLE_NONE,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE,
|
||||
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
||||
printf("5D (8,8,8,1,4) strides [1,8,8,8,4]: result=%d\n", r);
|
||||
}
|
||||
|
||||
cudaFree(d_data);
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user