// nvfp4-megamoe-kernel/tests/unit/test_tma_5d.cu
// (Repository viewer metadata: 139 lines, 5.7 KiB, shown as plain text. The
// viewer warned that this file contains Unicode characters that might be
// confused with other characters -- presumably non-ASCII symbols in comments;
// review them if that is not intentional.)
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdio>
// Raw 16-bit storage type standing in for a BF16 element; used here only to
// type device pointers (no arithmetic is performed through this alias).
using bf16_t = unsigned short;
// Encodes one rank-5, non-interleaved, tiled BF16 TMA descriptor and prints
// the driver's verdict as "<label>: result=<code> (<name>)".
//
// Parameter contract of cuTensorMapEncodeTiled relied on here (interleave NONE):
//   gdim[i]  (5 entries) size of dimension i, innermost (contiguous) first.
//   gstr[j]  (4 entries = rank - 1) BYTE stride of dimension j+1. Dimension 0
//            is always packed, so its stride (the element size) is never
//            passed. Every entry must be a multiple of 16 and < 2^40.
//   box[i]   (5 entries) tile extent per dimension, each in [1, 256];
//            box[0] * sizeof(element) must be a multiple of 16 bytes and, when
//            swizzling, no larger than the swizzle span.
//   estr[i]  (5 entries) traversal step INSIDE the box, each in [1, 8]. This
//            is not a memory stride: all-ones loads a dense tile.
// Returns true when the descriptor was accepted (CUDA_SUCCESS).
static bool tryEncodeBf16Tiled5d(const char* label, void* base,
                                 const cuuint64_t* gdim, const cuuint64_t* gstr,
                                 const cuuint32_t* box, const cuuint32_t* estr,
                                 CUtensorMapSwizzle swizzle) {
    CUtensorMap tma;
    const CUresult r = cuTensorMapEncodeTiled(&tma, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 5, base,
                                              gdim, gstr, box, estr,
                                              CU_TENSOR_MAP_INTERLEAVE_NONE,
                                              swizzle,
                                              CU_TENSOR_MAP_L2_PROMOTION_NONE,
                                              CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    const char* name = nullptr;
    if (cuGetErrorName(r, &name) != CUDA_SUCCESS || name == nullptr) {
        name = "unknown CUresult";
    }
    printf("%s: result=%d (%s)\n", label, static_cast<int>(r), name);
    return r == CUDA_SUCCESS;
}

// Host-only smoke test for TMA descriptor encoding. No kernel is launched and
// the device buffer is never read or written; the test only checks that the
// driver accepts each rank-5 BF16 descriptor describing the same buffer.
// Exit status: 0 when allocation, every encode, and the free succeed; 1 otherwise.
int main() {
    printf("=== TMA 5D + Interleave test ===\n");

    // Tensor under test: a densely packed row-major BF16 matrix with
    // 128 rows x 16 columns (2048 elements, 4096 bytes, 32-byte row pitch).
    constexpr cuuint64_t kRows = 128;
    constexpr cuuint64_t kCols = 16;
    constexpr cuuint64_t kRowPitch = kCols * sizeof(bf16_t);  // bytes
    constexpr cuuint64_t kBytes = kRows * kRowPitch;
    static_assert(kRowPitch % 16 == 0, "TMA global strides must be multiples of 16 bytes");

    bf16_t* d_data = nullptr;
    cudaError_t err = cudaMalloc(&d_data, kBytes);
    if (err != cudaSuccess) {
        fprintf(stderr, "cudaMalloc failed: %s\n", cudaGetErrorString(err));
        return 1;
    }
    // cudaMalloc returns memory aligned to at least 256 bytes, which covers
    // TMA's 16-byte global-address requirement; no manual re-alignment needed.
    printf("ptr: %p (256B aligned)\n", static_cast<void*>(d_data));

    int failures = 0;
    const cuuint32_t kDense[5] = {1, 1, 1, 1, 1};  // load every element

    // Cases 1 and 2: the matrix as rank 5 with three trailing size-1 dims:
    //   dim0 = 16 columns (contiguous), dim1 = 128 rows, dim2..dim4 = 1.
    //   Byte strides of dims 1..4: row pitch, then the whole matrix.
    const cuuint64_t matDim[5] = {kCols, kRows, 1, 1, 1};
    const cuuint64_t matStride[4] = {kRowPitch, kBytes, kBytes, kBytes};
    const cuuint32_t matBox[5] = {static_cast<cuuint32_t>(kCols),
                                  static_cast<cuuint32_t>(kRows), 1, 1, 1};
    if (!tryEncodeBf16Tiled5d("5D NONE", d_data, matDim, matStride, matBox, kDense,
                              CU_TENSOR_MAP_SWIZZLE_NONE)) {
        ++failures;
    }
    // With 128B swizzle the inner box row is 16 * 2 = 32 bytes, which is within
    // the 128-byte swizzle span.
    if (!tryEncodeBf16Tiled5d("5D SWIZZLE_128B", d_data, matDim, matStride, matBox, kDense,
                              CU_TENSOR_MAP_SWIZZLE_128B)) {
        ++failures;
    }

    // Case 3: the same 2048 contiguous elements viewed as (8, 8, 8, 1, 4),
    // innermost first (8 * 8 * 8 * 1 * 4 = 2048). Packed byte strides:
    //   dim1 = 8 * 2 = 16, dim2 = 16 * 8 = 128, dim3 = 128 * 8 = 1024,
    //   dim4 = 1024 * 1 = 1024   (4 * 1024 = 4096 bytes total).
    // box[0] * 2 = 16 bytes satisfies the 16-byte inner-box rule.
    const cuuint64_t splitDim[5] = {8, 8, 8, 1, 4};
    const cuuint64_t splitStride[4] = {16, 128, 1024, 1024};
    const cuuint32_t splitBox[5] = {8, 8, 8, 1, 4};
    if (!tryEncodeBf16Tiled5d("5D (8,8,8,1,4) globalStrides [16,128,1024,1024]", d_data,
                              splitDim, splitStride, splitBox, kDense,
                              CU_TENSOR_MAP_SWIZZLE_NONE)) {
        ++failures;
    }

    err = cudaFree(d_data);
    if (err != cudaSuccess) {
        fprintf(stderr, "cudaFree failed: %s\n", cudaGetErrorString(err));
        ++failures;
    }

    printf("%s (%d failure(s))\n", failures == 0 ? "PASS" : "FAIL", failures);
    return failures == 0 ? 0 : 1;
}