P4: test TMA with bit-21 workaround and innermost-first dims
This commit is contained in:
171
tests/unit/test_p4_tma_bit21_fix.cu
Normal file
171
tests/unit/test_p4_tma_bit21_fix.cu
Normal file
@@ -0,0 +1,171 @@
|
||||
/**
|
||||
* P4: Test TMA load with the bit-21 workaround from CUTLASS.
|
||||
*
|
||||
* Root cause of the TMA hang: driver 13.0 can't read descriptors
|
||||
* created by toolkit 13.2's cuTensorMapEncodeTiled. CUTLASS clears
|
||||
* bit 21 of desc[1] as a workaround for driver <= 13.1 with small tensors.
|
||||
*
|
||||
* This test:
|
||||
* 1. Creates a 2D TMA descriptor with NO swizzle
|
||||
* 2. Dumps the descriptor bytes
|
||||
* 3. Clears bit 21 of word[1] (the 64-bit word at offset 8)
|
||||
* 4. Dumps the modified descriptor
|
||||
* 5. Tests the TMA load with both descriptors
|
||||
*/
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
__global__ void tma_load_kernel(
|
||||
const void* tma_desc_ptr,
|
||||
int* result
|
||||
) {
|
||||
__shared__ uint64_t mbar;
|
||||
__shared__ uint16_t smem_out[16 * 16];
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
uint32_t mbar_addr = __cvta_generic_to_shared(&mbar);
|
||||
asm volatile("mbarrier.init.shared.b64 [%0], 1;" :: "r"(mbar_addr));
|
||||
asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
uint32_t smem_addr = __cvta_generic_to_shared(smem_out);
|
||||
uint32_t mbar_addr = __cvta_generic_to_shared(&mbar);
|
||||
asm volatile(
|
||||
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes "
|
||||
"[%0], [%1, {%3, %4}], [%2];"
|
||||
:: "r"(smem_addr),
|
||||
"l"(tma_desc_ptr),
|
||||
"r"(mbar_addr),
|
||||
"r"(0),
|
||||
"r"(0)
|
||||
);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
uint32_t mbar_addr = __cvta_generic_to_shared(&mbar);
|
||||
int waited = 0;
|
||||
while (waited < 1000000) {
|
||||
uint32_t state;
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
"mbarrier.try_wait.parity.shared.b64 p, [%0], 0;\n\t"
|
||||
"selp.b32 %1, 1, 0, p;\n\t"
|
||||
"}"
|
||||
: "=r"(state)
|
||||
: "r"(mbar_addr)
|
||||
);
|
||||
if (state) { *result = 1; return; }
|
||||
waited++;
|
||||
}
|
||||
*result = -1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
CUtensorMap create_descriptor(void* d_ptr, bool clear_bit21) {
|
||||
CUtensorMap desc;
|
||||
// (128, 16) BF16, row-major, NO swizzle
|
||||
// CUDA 13: globalDim is innermost-first, globalStrides in bytes, rank-1 strides
|
||||
cuuint64_t globalDim[] = {16, 128}; // (cols, rows) innermost-first
|
||||
cuuint64_t globalStrides[] = {16 * 2}; // row stride in bytes (rank-1 strides!)
|
||||
cuuint32_t boxDim[] = {16, 16};
|
||||
cuuint32_t elementStrides[] = {1, 1};
|
||||
|
||||
CUresult res = cuTensorMapEncodeTiled(&desc,
|
||||
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2,
|
||||
d_ptr, globalDim, globalStrides, boxDim, elementStrides,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
|
||||
|
||||
if (res != CUDA_SUCCESS) {
|
||||
printf(" cuTensorMapEncodeTiled FAILED: %d\n", res);
|
||||
return desc;
|
||||
}
|
||||
|
||||
// Apply bit-21 workaround if requested
|
||||
if (clear_bit21) {
|
||||
uint64_t* words = reinterpret_cast<uint64_t*>(&desc);
|
||||
words[1] &= ~(1ULL << 21);
|
||||
}
|
||||
|
||||
return desc;
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
const size_t DATA_SIZE = 128 * 16 * 2; // (128, 16) BF16
|
||||
|
||||
void* d_data;
|
||||
cudaMalloc(&d_data, DATA_SIZE);
|
||||
cudaMemset(d_data, 1, DATA_SIZE);
|
||||
|
||||
int* d_result;
|
||||
cudaMalloc(&d_result, sizeof(int));
|
||||
|
||||
// Test 1: Original descriptor (no bit-21 fix)
|
||||
printf("=== Test 1: Original descriptor (no fix) ===\n");
|
||||
{
|
||||
CUtensorMap desc = create_descriptor(d_data, false);
|
||||
// Dump first 16 bytes
|
||||
auto* b = reinterpret_cast<const uint8_t*>(&desc);
|
||||
printf(" Bytes [0-7]: "); for (int j=0;j<8;j++) printf("%02x ", b[j]); printf("\n");
|
||||
printf(" Bytes [8-15]: "); for (int j=0;j<8;j++) printf("%02x ", b[8+j]); printf("\n");
|
||||
|
||||
void* d_desc;
|
||||
cudaMalloc(&d_desc, sizeof(CUtensorMap));
|
||||
cudaMemcpy(d_desc, &desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
|
||||
cudaMemset(d_result, 0, sizeof(int));
|
||||
tma_load_kernel<<<1, 32>>>(d_desc, d_result);
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
int h_result;
|
||||
cudaMemcpy(&h_result, d_result, sizeof(int), cudaMemcpyDeviceToHost);
|
||||
|
||||
if (err != cudaSuccess) printf(" ERROR: %s (result=%d)\n", cudaGetErrorString(err), h_result);
|
||||
else if (h_result == 1) printf(" SUCCESS\n");
|
||||
else if (h_result == -1) printf(" HANG (mbarrier timeout)\n");
|
||||
else printf(" UNKNOWN: result=%d\n", h_result);
|
||||
|
||||
cudaFree(d_desc);
|
||||
}
|
||||
|
||||
// Test 2: Bit-21 cleared (CUTLASS workaround)
|
||||
printf("\n=== Test 2: Bit-21 cleared (CUTLASS workaround) ===\n");
|
||||
{
|
||||
CUtensorMap desc = create_descriptor(d_data, true);
|
||||
auto* b = reinterpret_cast<const uint8_t*>(&desc);
|
||||
printf(" Bytes [0-7]: "); for (int j=0;j<8;j++) printf("%02x ", b[j]); printf("\n");
|
||||
printf(" Bytes [8-15]: "); for (int j=0;j<8;j++) printf("%02x ", b[8+j]); printf("\n");
|
||||
|
||||
void* d_desc;
|
||||
cudaMalloc(&d_desc, sizeof(CUtensorMap));
|
||||
cudaMemcpy(d_desc, &desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
|
||||
cudaMemset(d_result, 0, sizeof(int));
|
||||
tma_load_kernel<<<1, 32>>>(d_desc, d_result);
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
int h_result;
|
||||
cudaMemcpy(&h_result, d_result, sizeof(int), cudaMemcpyDeviceToHost);
|
||||
|
||||
if (err != cudaSuccess) printf(" ERROR: %s (result=%d)\n", cudaGetErrorString(err), h_result);
|
||||
else if (h_result == 1) printf(" SUCCESS\n");
|
||||
else if (h_result == -1) printf(" HANG (mbarrier timeout)\n");
|
||||
else printf(" UNKNOWN: result=%d\n", h_result);
|
||||
|
||||
cudaFree(d_desc);
|
||||
}
|
||||
|
||||
cudaFree(d_data);
|
||||
cudaFree(d_result);
|
||||
printf("\nPASSED\n");
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user