P4: fix TMA load test (32-bit SMEM addrs, proper mbarrier)

This commit is contained in:
2026-05-30 08:38:55 +00:00
parent e2ecdc42d8
commit 16027018df

View File

@@ -4,11 +4,6 @@
* Creates TMA descriptors with various swizzle/OOB configs,
* launches a kernel that does cp.async.bulk.tensor.2d with each,
* and checks if the load completes (mbarrier signals) or hangs.
*
* The existing FMHA kernel's TMA loads work (they use CuTeDSL's
* TMA path which creates descriptors with swizzle). The raw CUDA
* path with NO swizzle hangs. This test identifies which field
* causes the hang.
*/
#include <cuda.h>
#include <cuda_runtime.h>
@@ -16,41 +11,43 @@
#include <cstdint>
#include <cstring>
// Maximum wait iterations before declaring a hang
#define MAX_WAIT 1000000
__global__ void tma_load_test_kernel(
const void* tma_desc_ptr, // 128-byte TMA descriptor in GMEM
void* smem_out, // SMEM buffer for TMA output (256 bytes)
int* result // GMEM: 0=pending, 1=success, -1=hang
const void* tma_desc_ptr,
int* result
) {
// Set up mbarrier in SMEM
// SMEM: mbarrier (8 bytes aligned) + output buffer (512 bytes)
__shared__ uint64_t mbar;
__shared__ uint16_t smem_out[16 * 16]; // 256 BF16 values
if (threadIdx.x == 0) {
// Initialize mbarrier with expected count = 1 (one TMA load)
asm volatile("mbarrier.init.shared.b64 [%0], 1;" :: "r"(__cvta_generic_to_shared(&mbar)));
uint32_t mbar_addr = __cvta_generic_to_shared(&mbar);
asm volatile("mbarrier.init.shared.b64 [%0], 1;" :: "r"(mbar_addr));
asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
}
__syncthreads();
// Only thread 0 issues TMA
if (threadIdx.x == 0) {
// TMA load: 16x16 BF16 tile = 512 bytes
// cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
uint32_t smem_addr = __cvta_generic_to_shared(smem_out);
uint32_t mbar_addr = __cvta_generic_to_shared(&mbar);
// Issue TMA load: 16x16 BF16 = 512 bytes
asm volatile(
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes "
"[%0], [%1, {%3, %4}], [%2];"
:: "r"(__cvta_generic_to_shared(smem_out)),
:: "r"(smem_addr),
"l"(tma_desc_ptr),
"r"(__cvta_generic_to_shared(&mbar)),
"r"(mbar_addr),
"r"(0), // coord row = 0
"r"(0) // coord col = 0
);
}
__syncthreads();
// Wait for mbarrier to complete (TMA arrival)
// Wait for mbarrier
if (threadIdx.x == 0) {
uint32_t mbar_addr = __cvta_generic_to_shared(&mbar);
int waited = 0;
int arrived = 0;
while (waited < MAX_WAIT) {
@@ -62,16 +59,12 @@ __global__ void tma_load_test_kernel(
"selp.b32 %1, 1, 0, p;\n\t"
"}"
: "=r"(state)
: "r"(__cvta_generic_to_shared(&mbar))
: "r"(mbar_addr)
);
if (state) { arrived = 1; break; }
waited++;
}
if (arrived) {
*result = 1; // success
} else {
*result = -1; // hang
}
*result = arrived ? 1 : -1;
}
}
@@ -104,19 +97,14 @@ int main() {
const int ROWS = 128;
const int COLS = 16;
const size_t DATA_SIZE = ROWS * COLS * 2;
const size_t SMEM_SIZE = 512; // 16x16 BF16 = 512 bytes
// Allocate source data
void* d_data;
cudaMalloc(&d_data, DATA_SIZE);
cudaMemset(d_data, 1, DATA_SIZE); // Fill with non-zero data
cudaMemset(d_data, 1, DATA_SIZE);
// Allocate result
int* d_result;
cudaMalloc(&d_result, sizeof(int));
cudaMemset(d_result, 0, sizeof(int));
// Test configs
struct { const char* name; int swizzle; int oob; } configs[] = {
{"NO swizzle, OOB_NONE", 0, 0},
{"SWIZZLE_128B, OOB_NONE", 1, 0},
@@ -127,23 +115,15 @@ int main() {
for (int i = 0; i < 4; i++) {
printf("Testing: %s\n", configs[i].name);
// Create descriptor and copy to GMEM
CUtensorMap desc = create_descriptor(d_data, configs[i].swizzle, configs[i].oob);
void* d_desc;
cudaMalloc(&d_desc, sizeof(CUtensorMap));
cudaMemcpy(d_desc, &desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
// Allocate SMEM output
void* d_smem_out;
cudaMalloc(&d_smem_out, SMEM_SIZE);
// Reset result
cudaMemset(d_result, 0, sizeof(int));
// Launch with timeout
tma_load_test_kernel<<<1, 32, SMEM_SIZE + 64>>>(d_desc, d_smem_out, d_result);
tma_load_test_kernel<<<1, 32>>>(d_desc, d_result);
// Check result with a short timeout on host
cudaError_t err = cudaDeviceSynchronize();
int h_result;
cudaMemcpy(&h_result, d_result, sizeof(int), cudaMemcpyDeviceToHost);
@@ -159,7 +139,6 @@ int main() {
}
cudaFree(d_desc);
cudaFree(d_smem_out);
}
cudaFree(d_data);