/**
 * P4: Test TMA load with the bit-21 workaround from CUTLASS.
 * (tests/unit/test_p4_tma_bit21_fix.cu)
 *
 * Hypothesis under test: driver 13.0 can't read descriptors created by
 * toolkit 13.2's cuTensorMapEncodeTiled. CUTLASS clears bit 21 of desc[1]
 * as a workaround for driver <= 13.1 with small tensors.
 *
 * This test:
 *   1. Creates a 2D TMA descriptor with NO swizzle
 *   2. Dumps the first 16 descriptor bytes
 *   3. Clears bit 21 of word[1] (the 64-bit word at offset 8)
 *   4. Dumps the modified descriptor
 *   5. Runs the same TMA load kernel with both descriptors
 *
 * Exit status: 0 only when the load through the bit-21-cleared descriptor
 * completes. The unmodified descriptor is run for diagnosis only, because
 * its failure is the symptom being investigated.
 *
 * Requires SM90+ (cp.async.bulk.tensor / mbarrier transaction counts) and
 * linking against the CUDA driver library for cuTensorMapEncodeTiled.
 */
#include <cuda.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <cstdio>

/**
 * Issues one 16x16 BF16 TMA tile load at coordinate (0, 0) and reports
 * whether the mbarrier tracking it completed.
 *
 * Launch layout: one block; only thread 0 does work, the other threads just
 * take part in the block barriers. Uses 8 + 512 bytes of static shared memory.
 *
 * @param tma_desc_ptr  Device-visible address of a CUtensorMap.
 *                      NOTE(review): the descriptor is passed through global
 *                      memory rather than as a __grid_constant__ kernel
 *                      parameter; confirm this is intended for the driver
 *                      versions under test.
 * @param result        Device int, written by thread 0: 1 when the barrier
 *                      phase completed, -1 when polling gave up.
 */
__global__ void tma_load_kernel(
    const void* tma_desc_ptr,
    int* result
) {
    __shared__ __align__(8) uint64_t mbar;
    // Bulk tensor copies need a 128-byte aligned shared-memory destination.
    __shared__ __align__(128) uint16_t smem_out[16 * 16];

    if (threadIdx.x == 0) {
        uint32_t mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&mbar));
        // One expected arrival: thread 0's arrive.expect_tx below.
        asm volatile("mbarrier.init.shared.b64 [%0], 1;" :: "r"(mbar_addr) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    __syncthreads();

    if (threadIdx.x == 0) {
        uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_out));
        uint32_t mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&mbar));
        uint32_t tx_bytes = static_cast<uint32_t>(sizeof(smem_out));
        int coord_inner = 0;
        int coord_outer = 0;

        // Arrive once and declare how many bytes the TMA will deliver. The
        // copy's complete_tx only retires the byte count; it does not count as
        // an arrival. Without this arrive the phase can never complete and the
        // wait below times out for every descriptor.
        asm volatile(
            "{\n\t"
            ".reg .b64 state;\n\t"
            "mbarrier.arrive.expect_tx.shared::cta.b64 state, [%0], %1;\n\t"
            "}"
            :: "r"(mbar_addr), "r"(tx_bytes)
            : "memory"
        );

        asm volatile(
            "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes "
            "[%0], [%1, {%3, %4}], [%2];"
            :: "r"(smem_addr),
               "l"(tma_desc_ptr),
               "r"(mbar_addr),
               "r"(coord_inner),
               "r"(coord_outer)
            : "memory"
        );
    }
    __syncthreads();

    if (threadIdx.x == 0) {
        uint32_t mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&mbar));
        int waited = 0;
        // Bounded polling so a load that never completes is reported as a
        // result code instead of hanging the process.
        while (waited < 1000000) {
            uint32_t state;
            asm volatile(
                "{\n\t"
                ".reg .pred p;\n\t"
                "mbarrier.try_wait.parity.shared.b64 p, [%0], 0;\n\t"
                "selp.b32 %1, 1, 0, p;\n\t"
                "}"
                : "=r"(state)
                : "r"(mbar_addr)
                : "memory"
            );
            if (state) { *result = 1; return; }
            waited++;
        }
        *result = -1;
    }
}


/**
 * Encodes a tiled tensor map for a (128, 16) BF16 row-major tensor with a
 * 16x16 box, no interleave and no swizzle.
 *
 * @param d_ptr        Device pointer to the tensor data.
 * @param clear_bit21  When true, clears bit 21 of the descriptor's second
 *                     64-bit word after a successful encode.
 * @param ok           Optional; receives true when cuTensorMapEncodeTiled
 *                     succeeded, false otherwise.
 * @return The encoded descriptor. If encoding fails a message is printed and
 *         the returned descriptor must not be used; it is zero-initialised up
 *         front so it never holds indeterminate bytes.
 */
CUtensorMap create_descriptor(void* d_ptr, bool clear_bit21, bool* ok = nullptr) {
    CUtensorMap desc{};
    if (ok != nullptr) *ok = false;

    // (128, 16) BF16, row-major, NO swizzle
    // CUDA 13: globalDim is innermost-first, globalStrides in bytes, rank-1 strides
    cuuint64_t globalDim[] = {16, 128};       // (cols, rows) innermost-first
    cuuint64_t globalStrides[] = {16 * 2};    // row stride in bytes (rank-1 strides!)
    cuuint32_t boxDim[] = {16, 16};
    cuuint32_t elementStrides[] = {1, 1};

    CUresult res = cuTensorMapEncodeTiled(&desc,
        CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2,
        d_ptr, globalDim, globalStrides, boxDim, elementStrides,
        CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE,
        CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

    if (res != CUDA_SUCCESS) {
        printf(" cuTensorMapEncodeTiled FAILED: %d\n", res);
        return desc;
    }

    // Apply bit-21 workaround if requested
    if (clear_bit21) {
        uint64_t* words = reinterpret_cast<uint64_t*>(&desc);
        words[1] &= ~(1ULL << 21);
    }

    if (ok != nullptr) *ok = true;
    return desc;
}


/** Prints a diagnostic and returns false when a CUDA runtime call failed. */
static bool cuda_ok(cudaError_t err, const char* what) {
    if (err == cudaSuccess) return true;
    printf(" ERROR: %s failed: %s\n", what, cudaGetErrorString(err));
    return false;
}


/** Hex-dumps the first 16 bytes (the first two 64-bit words) of a descriptor. */
static void dump_descriptor_head(const CUtensorMap& desc) {
    const uint8_t* b = reinterpret_cast<const uint8_t*>(&desc);
    printf(" Bytes [0-7]: ");
    for (int j = 0; j < 8; j++) printf("%02x ", b[j]);
    printf("\n");
    printf(" Bytes [8-15]: ");
    for (int j = 0; j < 8; j++) printf("%02x ", b[8 + j]);
    printf("\n");
}


/**
 * Runs one test case: encode the descriptor, dump it, copy it to the device,
 * launch tma_load_kernel and print the outcome.
 *
 * @return true only when the kernel reported that the load completed.
 */
static bool run_case(const char* banner, void* d_data, int* d_result, bool clear_bit21) {
    printf("%s", banner);

    bool encoded = false;
    CUtensorMap desc = create_descriptor(d_data, clear_bit21, &encoded);
    if (!encoded) return false;  // never launch with an unusable descriptor

    dump_descriptor_head(desc);

    void* d_desc = nullptr;
    if (!cuda_ok(cudaMalloc(&d_desc, sizeof(CUtensorMap)), "cudaMalloc(descriptor)")) {
        return false;
    }

    bool success = false;
    if (cuda_ok(cudaMemcpy(d_desc, &desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice),
                "cudaMemcpy(descriptor)") &&
        cuda_ok(cudaMemset(d_result, 0, sizeof(int)), "cudaMemset(result)")) {
        tma_load_kernel<<<1, 32>>>(d_desc, d_result);

        // Launch-configuration errors surface here; execution errors surface
        // at the synchronize.
        cudaError_t err = cudaGetLastError();
        if (err == cudaSuccess) err = cudaDeviceSynchronize();

        int h_result = 0;  // defined value even if the copy-back fails
        const cudaError_t copy_err =
            cudaMemcpy(&h_result, d_result, sizeof(int), cudaMemcpyDeviceToHost);
        if (err == cudaSuccess) err = copy_err;

        if (err != cudaSuccess) {
            printf(" ERROR: %s (result=%d)\n", cudaGetErrorString(err), h_result);
        } else if (h_result == 1) {
            printf(" SUCCESS\n");
            success = true;
        } else if (h_result == -1) {
            printf(" HANG (mbarrier timeout)\n");
        } else {
            printf(" UNKNOWN: result=%d\n", h_result);
        }
    }

    cudaFree(d_desc);
    return success;
}


int main() {
    const size_t DATA_SIZE = 128 * 16 * 2;  // (128, 16) BF16

    void* d_data = nullptr;
    if (!cuda_ok(cudaMalloc(&d_data, DATA_SIZE), "cudaMalloc(data)")) return 1;
    if (!cuda_ok(cudaMemset(d_data, 1, DATA_SIZE), "cudaMemset(data)")) {
        cudaFree(d_data);
        return 1;
    }

    int* d_result = nullptr;
    if (!cuda_ok(cudaMalloc(&d_result, sizeof(int)), "cudaMalloc(result)")) {
        cudaFree(d_data);
        return 1;
    }

    // Test 1: Original descriptor (no bit-21 fix) -- diagnostic only.
    run_case("=== Test 1: Original descriptor (no fix) ===\n",
             d_data, d_result, false);

    // Test 2: Bit-21 cleared (CUTLASS workaround) -- decides the exit status.
    const bool workaround_ok =
        run_case("\n=== Test 2: Bit-21 cleared (CUTLASS workaround) ===\n",
                 d_data, d_result, true);

    cudaFree(d_data);
    cudaFree(d_result);

    if (!workaround_ok) {
        printf("\nFAILED\n");
        return 1;
    }
    printf("\nPASSED\n");
    return 0;
}