P4: test TMA with bit-21 workaround and innermost-first dims

This commit is contained in:
2026-05-30 08:40:21 +00:00
parent 16027018df
commit 90c806733f

View File

@@ -0,0 +1,171 @@
/**
* P4: Test TMA load with the bit-21 workaround from CUTLASS.
*
* Root cause of the TMA hang: driver 13.0 can't read descriptors
* created by toolkit 13.2's cuTensorMapEncodeTiled. CUTLASS clears
* bit 21 of desc[1] as a workaround for driver <= 13.1 with small tensors.
*
* This test:
* 1. Creates a 2D TMA descriptor with NO swizzle
* 2. Dumps the descriptor bytes
* 3. Clears bit 21 of word[1] (the 64-bit word at offset 8)
* 4. Dumps the modified descriptor
* 5. Tests the TMA load with both descriptors
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdint>
#include <cstring>
/**
 * Issues a single 2D TMA tile load at tensor coordinates {0, 0} through the
 * tensor map at `tma_desc_ptr`, then polls the completion mbarrier.
 *
 * Launch layout: one block is enough; only thread 0 initialises the barrier,
 * issues the copy, and polls. The other threads just pass the two
 * __syncthreads() calls and exit.
 * Shared memory: one 64-bit mbarrier plus a 16x16 tile of 2-byte elements.
 * Requires SM90+ (cp.async.bulk.tensor and mbarrier transaction counting).
 *
 * @param tma_desc_ptr  Address handed to the TMA instruction as its tensor map
 *                      operand. NOTE(review): the expected byte count below
 *                      assumes the map's box is 16x16 two-byte elements,
 *                      matching smem_out — confirm against the host encoder.
 * @param result        Written by thread 0: 1 when the barrier phase completed,
 *                      -1 after 1000000 unsuccessful polls.
 */
__global__ void tma_load_kernel(
    const void* tma_desc_ptr,
    int* result
) {
    // Number of bytes the tile copy is expected to deposit into smem_out.
    constexpr uint32_t kTileBytes = 16 * 16 * sizeof(uint16_t);

    __shared__ uint64_t mbar;
    // TMA tensor copies need an aligned shared-memory destination; request
    // 128-byte alignment explicitly instead of relying on placement luck.
    __shared__ alignas(128) uint16_t smem_out[16 * 16];

    if (threadIdx.x == 0) {
        uint32_t mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&mbar));
        // One arriving thread (thread 0) participates in each phase.
        asm volatile("mbarrier.init.shared.b64 [%0], 1;" :: "r"(mbar_addr) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    __syncthreads();

    if (threadIdx.x == 0) {
        uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_out));
        uint32_t mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&mbar));

        // The tensor map is read through a pointer rather than a kernel
        // parameter, so acquire it for the tensormap proxy before first use.
        asm volatile(
            "fence.proxy.tensormap::generic.acquire.sys [%0], 128;"
            :: "l"(tma_desc_ptr)
            : "memory"
        );

        // Arrive on the barrier and announce the expected transaction size.
        // Without this the pending-arrival count never reaches zero and the
        // phase cannot complete, no matter whether the copy itself succeeds.
        asm volatile(
            "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;"
            :: "r"(mbar_addr),
               "r"(kTileBytes)
            : "memory"
        );

        asm volatile(
            "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes "
            "[%0], [%1, {%3, %4}], [%2];"
            :: "r"(smem_addr),
               "l"(tma_desc_ptr),
               "r"(mbar_addr),
               "r"(0),
               "r"(0)
            : "memory"
        );
    }
    __syncthreads();

    if (threadIdx.x == 0) {
        uint32_t mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&mbar));
        // Bounded poll on phase parity 0 so a copy that never completes is
        // reported as -1 instead of hanging the kernel forever.
        int waited = 0;
        while (waited < 1000000) {
            uint32_t state;
            asm volatile(
                "{\n\t"
                ".reg .pred p;\n\t"
                "mbarrier.try_wait.parity.shared.b64 p, [%1], 0;\n\t"
                "selp.b32 %0, 1, 0, p;\n\t"
                "}"
                : "=r"(state)
                : "r"(mbar_addr)
                : "memory"
            );
            if (state) { *result = 1; return; }
            waited++;
        }
        *result = -1;
    }
}
/**
 * Encodes a 2D tiled tensor map over `d_ptr` and optionally applies the
 * bit-21 workaround to it.
 *
 * The map describes a (128, 16) BF16 row-major tensor with a 16x16 box,
 * no interleave, no swizzle, no L2 promotion and no OOB fill.
 *
 * @param d_ptr        Base address passed to cuTensorMapEncodeTiled as the
 *                     global address of the tensor.
 * @param clear_bit21  When true, bit 21 of the second 64-bit word of the
 *                     encoded descriptor (byte offset 8) is cleared.
 * @return The encoded descriptor. If encoding fails, a message is printed and
 *         an all-zero descriptor is returned so the caller never sees
 *         indeterminate bytes.
 */
CUtensorMap create_descriptor(void* d_ptr, bool clear_bit21) {
    CUtensorMap desc;
    // Start from a known state: the struct is returned by value on every path.
    std::memset(&desc, 0, sizeof(desc));

    // (128, 16) BF16, row-major, NO swizzle
    // CUDA 13: globalDim is innermost-first, globalStrides in bytes, rank-1 strides
    cuuint64_t globalDim[] = {16, 128};      // (cols, rows) innermost-first
    cuuint64_t globalStrides[] = {16 * 2};   // row stride in bytes (rank-1 strides!)
    cuuint32_t boxDim[] = {16, 16};
    cuuint32_t elementStrides[] = {1, 1};

    CUresult res = cuTensorMapEncodeTiled(&desc,
        CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2,
        d_ptr, globalDim, globalStrides, boxDim, elementStrides,
        CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE,
        CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    if (res != CUDA_SUCCESS) {
        printf(" cuTensorMapEncodeTiled FAILED: %d\n", static_cast<int>(res));
        // Discard anything the failed call may have written.
        std::memset(&desc, 0, sizeof(desc));
        return desc;
    }

    // Apply bit-21 workaround if requested
    if (clear_bit21) {
        // Patch the 64-bit word at byte offset 8 through memcpy so the edit
        // does not depend on aliasing the opaque struct as uint64_t.
        unsigned char* raw = reinterpret_cast<unsigned char*>(&desc);
        uint64_t word1;
        std::memcpy(&word1, raw + sizeof(uint64_t), sizeof(word1));
        word1 &= ~(1ULL << 21);
        std::memcpy(raw + sizeof(uint64_t), &word1, sizeof(word1));
    }
    return desc;
}
/**
 * Runs one descriptor variant end to end: builds the descriptor, dumps its
 * first 16 bytes, copies it to device memory, launches tma_load_kernel with
 * one block of 32 threads, and prints the outcome.
 *
 * @param d_data       Device buffer the descriptor is encoded over.
 * @param d_result     Device int the kernel writes its status into.
 * @param clear_bit21  Forwarded to create_descriptor.
 * @return true only when every CUDA call succeeded and the kernel reported 1.
 */
static bool run_descriptor_case(void* d_data, int* d_result, bool clear_bit21) {
    CUtensorMap desc = create_descriptor(d_data, clear_bit21);

    // Dump first 16 bytes
    const uint8_t* b = reinterpret_cast<const uint8_t*>(&desc);
    printf(" Bytes [0-7]: ");
    for (int j = 0; j < 8; j++) printf("%02x ", b[j]);
    printf("\n");
    printf(" Bytes [8-15]: ");
    for (int j = 0; j < 8; j++) printf("%02x ", b[8 + j]);
    printf("\n");

    void* d_desc = nullptr;
    cudaError_t err = cudaMalloc(&d_desc, sizeof(CUtensorMap));
    if (err == cudaSuccess)
        err = cudaMemcpy(d_desc, &desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
    if (err == cudaSuccess)
        err = cudaMemset(d_result, 0, sizeof(int));
    if (err != cudaSuccess) {
        printf(" SETUP ERROR: %s\n", cudaGetErrorString(err));
        cudaFree(d_desc);
        return false;
    }

    tma_load_kernel<<<1, 32>>>(d_desc, d_result);
    err = cudaGetLastError();                       // launch-time failures
    if (err == cudaSuccess) err = cudaDeviceSynchronize();  // execution failures

    // Initialised so a failed copy-back never prints an indeterminate value.
    int h_result = 0;
    cudaError_t copy_err = cudaMemcpy(&h_result, d_result, sizeof(int), cudaMemcpyDeviceToHost);
    if (err == cudaSuccess) err = copy_err;

    bool ok = false;
    if (err != cudaSuccess) {
        printf(" ERROR: %s (result=%d)\n", cudaGetErrorString(err), h_result);
    } else if (h_result == 1) {
        printf(" SUCCESS\n");
        ok = true;
    } else if (h_result == -1) {
        printf(" HANG (mbarrier timeout)\n");
    } else {
        printf(" UNKNOWN: result=%d\n", h_result);
    }

    cudaFree(d_desc);
    return ok;
}

/**
 * Allocates the test tensor and result slot, runs the load once with the
 * unmodified descriptor and once with bit 21 cleared, and reports an overall
 * verdict.
 *
 * @return 0 when at least one descriptor variant loaded successfully,
 *         1 on allocation failure or when neither variant worked.
 */
int main() {
    const size_t DATA_SIZE = 128 * 16 * 2;  // (128, 16) BF16

    void* d_data = nullptr;
    int* d_result = nullptr;
    cudaError_t err = cudaMalloc(&d_data, DATA_SIZE);
    if (err == cudaSuccess) err = cudaMemset(d_data, 1, DATA_SIZE);
    if (err == cudaSuccess) err = cudaMalloc(&d_result, sizeof(int));
    if (err != cudaSuccess) {
        printf("SETUP ERROR: %s\n", cudaGetErrorString(err));
        cudaFree(d_data);
        cudaFree(d_result);
        return 1;
    }

    // Test 1: Original descriptor (no bit-21 fix)
    printf("=== Test 1: Original descriptor (no fix) ===\n");
    const bool original_ok = run_descriptor_case(d_data, d_result, false);

    // Test 2: Bit-21 cleared (CUTLASS workaround)
    printf("\n=== Test 2: Bit-21 cleared (CUTLASS workaround) ===\n");
    const bool workaround_ok = run_descriptor_case(d_data, d_result, true);

    cudaFree(d_data);
    cudaFree(d_result);

    // Only claim success when some descriptor variant actually produced a
    // completed load; otherwise signal failure through the exit code too.
    if (original_ok || workaround_ok) {
        printf("\nPASSED\n");
        return 0;
    }
    printf("\nFAILED\n");
    return 1;
}