// (Paste residue from a file viewer: "139 lines / 5.7 KiB / Plaintext", shown twice.)
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <cstdio>

// 16-bit storage type holding a raw bfloat16 bit pattern (no BF16 arithmetic).
using bf16_t = unsigned short;

// Print one encode attempt: the label, the raw CUresult value, and its symbolic
// name when the driver can provide one.
static void reportEncode(const char* label, CUresult r) {
    const char* errName = nullptr;
    if (cuGetErrorName(r, &errName) != CUDA_SUCCESS || errName == nullptr) {
        errName = "unrecognized CUresult";
    }
    printf("%s: result=%d (%s)\n", label, static_cast<int>(r), errName);
}

// Encode a non-interleaved BF16 tiled tensor map whose box is read densely.
//
//   rank       number of dimensions (1..5)
//   globalDim  `rank` extents, innermost (contiguous) dimension first
//   globalStr  `rank - 1` byte strides for dimensions 1..rank-1. Dimension 0's
//              stride is implicitly the element size and must NOT be passed.
//              Each entry must be a multiple of 16.
//   boxDim     `rank` box extents, each in [1, 256]. With INTERLEAVE_NONE,
//              boxDim[0] * 2 bytes must be a multiple of 16 and, when a swizzle
//              mode is used, no larger than the swizzle span.
//
// elementStrides is NOT a memory stride: it is the step taken through the box
// in each dimension (1 = every element, 2 = every other element, ...), and each
// entry must lie in [1, 8]. A dense box therefore uses all ones, no matter how
// far apart rows are in memory -- row pitch belongs in globalStr.
static CUresult encodeDenseBf16(CUtensorMap* map, void* base, cuuint32_t rank,
                                const cuuint64_t* globalDim,
                                const cuuint64_t* globalStr,
                                const cuuint32_t* boxDim,
                                CUtensorMapSwizzle swizzle) {
    static const cuuint32_t kDenseElementStrides[5] = {1, 1, 1, 1, 1};
    return cuTensorMapEncodeTiled(map, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank, base,
                                  globalDim, globalStr, boxDim, kDenseElementStrides,
                                  CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
                                  CU_TENSOR_MAP_L2_PROMOTION_NONE,
                                  CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
}

// Host-only probe of cuTensorMapEncodeTiled. Allocates room for one 128 x 16
// BF16 matrix (2048 elements, 4096 bytes), encodes several equivalent
// descriptions of that buffer, and prints whether the driver accepts each one.
// No kernel is launched. Uses the driver API, so link with -lcuda; actually
// issuing TMA loads with the resulting descriptor needs an SM90+ device.
int main() {
    printf("=== TMA 5D + Interleave test ===\n");

    constexpr int kRows = 128;
    constexpr int kCols = 16;
    constexpr size_t kMatBytes = size_t(kRows) * kCols * sizeof(bf16_t);  // 4096

    bf16_t* d_data = nullptr;
    cudaError_t err = cudaMalloc(&d_data, kMatBytes + 256);
    if (err != cudaSuccess) {
        fprintf(stderr, "cudaMalloc failed: %s\n", cudaGetErrorString(err));
        return 1;
    }

    // TMA needs a 16-byte aligned global base. cudaMalloc already returns
    // >= 256-byte alignment; rounding up inside the 256-byte slack keeps the
    // guarantee explicit without running past the allocation.
    const uintptr_t addr =
        (reinterpret_cast<uintptr_t>(d_data) + 255) & ~static_cast<uintptr_t>(255);
    bf16_t* d_aligned = reinterpret_cast<bf16_t*>(addr);
    printf("ptr: %p (256B aligned)\n", static_cast<void*>(d_aligned));

    CUtensorMap tma;

    // Row-major 128 x 16: 16 contiguous BF16 per row -> 32-byte row pitch.
    const cuuint64_t rowPitch = kCols * sizeof(bf16_t);
    {
        const cuuint64_t gdim[] = {kCols, kRows};
        const cuuint64_t gstr[] = {rowPitch};
        const cuuint32_t box[]  = {kCols, kRows};
        reportEncode("2D 128x16 (16 contiguous) NONE",
                     encodeDenseBf16(&tma, d_aligned, 2, gdim, gstr, box,
                                     CU_TENSOR_MAP_SWIZZLE_NONE));
    }

    // The same 4096 bytes viewed as 16 x 128 with 128 contiguous elements per
    // row (256-byte pitch) -- the "K contiguous" layout.
    {
        const cuuint64_t gdim[] = {kRows, kCols};
        const cuuint64_t gstr[] = {kRows * sizeof(bf16_t)};
        const cuuint32_t box[]  = {kRows, kCols};
        reportEncode("2D 16x128 (128 contiguous) NONE",
                     encodeDenseBf16(&tma, d_aligned, 2, gdim, gstr, box,
                                     CU_TENSOR_MAP_SWIZZLE_NONE));
    }

    // 5D: the row-major 128 x 16 view padded with three unit outer dimensions.
    // Exactly rank-1 = 4 byte strides are passed; the unit dimensions reuse the
    // full-matrix size (index is always 0 there, but the value must still be a
    // multiple of 16).
    {
        const cuuint64_t gdim[] = {kCols, kRows, 1, 1, 1};
        const cuuint64_t gstr[] = {rowPitch, kMatBytes, kMatBytes, kMatBytes};
        const cuuint32_t box[]  = {kCols, kRows, 1, 1, 1};
        reportEncode("5D NONE",
                     encodeDenseBf16(&tma, d_aligned, 5, gdim, gstr, box,
                                     CU_TENSOR_MAP_SWIZZLE_NONE));
        // 128B swizzle needs boxDim[0] * 2 <= 128 bytes; here it is 32 bytes.
        reportEncode("5D SWIZZLE_128B",
                     encodeDenseBf16(&tma, d_aligned, 5, gdim, gstr, box,
                                     CU_TENSOR_MAP_SWIZZLE_128B));
    }

    // 5D (8, 8, 8, 1, 4): the same 2048 contiguous elements split into dense
    // sub-dimensions. Byte strides for dims 1..4 are 8*2, 64*2, 512*2, 512*2.
    {
        const cuuint64_t gdim[] = {8, 8, 8, 1, 4};
        const cuuint64_t gstr[] = {16, 128, 1024, 1024};
        const cuuint32_t box[]  = {8, 8, 8, 1, 4};
        reportEncode("5D (8,8,8,1,4) elementStrides [1,1,1,1,1]",
                     encodeDenseBf16(&tma, d_aligned, 5, gdim, gstr, box,
                                     CU_TENSOR_MAP_SWIZZLE_NONE));
    }

    err = cudaFree(d_data);
    if (err != cudaSuccess) {
        fprintf(stderr, "cudaFree failed: %s\n", cudaGetErrorString(err));
        return 1;
    }
    return 0;
}