From a766b488c2df1d570feeb758f0c3ee433084c762 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 29 May 2026 19:25:01 +0000 Subject: [PATCH] =?UTF-8?q?test:=20minimal=20TMA=20diagnostic=20=E2=80=94?= =?UTF-8?q?=20isolate=20multi-warp=20TMA=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_tma_minimal.cu | 318 +++++++++++++++++++++++++++++++++ 1 file changed, 318 insertions(+) create mode 100644 tests/unit/test_tma_minimal.cu diff --git a/tests/unit/test_tma_minimal.cu b/tests/unit/test_tma_minimal.cu new file mode 100644 index 00000000..b3fc0830 --- /dev/null +++ b/tests/unit/test_tma_minimal.cu @@ -0,0 +1,318 @@ +/** + * MINIMAL TMA diagnostic — no MMA, no TMEM. + * Just: TMA load K sub-tile → verify SMEM contents match GMEM. + * + * Tests with 32, 64, 128, 192 threads to isolate the multi-warp TMA bug. + * Build with -DNUM_THREADS=128 etc. + */ + +#include +#include +#include +#include +#include +#include + +#ifndef HD_VAL +#define HD_VAL 16 +#endif + +#ifndef NUM_THREADS +#define NUM_THREADS 32 +#endif + +typedef unsigned short bf16_t; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = HD_VAL; +constexpr int SK = 128; +constexpr int TILE_ROWS = 128; +constexpr int TILE_COLS = 16; // MMA K-tile +constexpr int TILE_BYTES = TILE_ROWS * TILE_COLS * 2; // 4096 + +// TMA helpers +__device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar, uint32_t count) { + asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(smem_mbar), "r"(count)); +} + +__device__ __forceinline__ void tma_mbarrier_arrive_expect_tx(uint32_t smem_mbar, uint32_t tx_bytes) { + asm volatile("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 _, [%0], %1;" + :: "r"(smem_mbar), "r"(tx_bytes) : "memory"); +} + +__device__ __forceinline__ void tma_load_2d( + uint32_t smem_dst, uint64_t tma_desc, uint32_t smem_mbar, + int coord_x, int coord_y +) { + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes " + "[%0], [%1, {%3, %4}], [%2];" + :: "r"(smem_dst), "l"(tma_desc), "r"(smem_mbar), "r"(coord_x), "r"(coord_y) + : "memory" + ); +} + +__device__ __forceinline__ void tma_mbarrier_wait(uint32_t smem_mbar, int phase) { + asm volatile( + "{\n\t" + ".reg .pred P1;\n\t" + "LAB_WAIT:" + "mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\n\t" + "@P1 bra.uni DONE;\n\t" + "bra.uni LAB_WAIT;\n\t" + "DONE:\n\t" + "}" + :: "r"(smem_mbar), "r"(phase), "r"(0x989680) + : "memory" + ); +} + +// Also test commit_group / wait_group approach +__device__ __forceinline__ void cp_async_bulk_commit_group() { + asm volatile("cp.async.bulk.commit_group;" ::: "memory"); +} + +template +__device__ __forceinline__ void cp_async_bulk_wait_group() { + asm volatile("cp.async.bulk.wait_group %0;" :: "n"(N) : "memory"); +} + +// Approach 1: mbarrier +__global__ void __launch_bounds__(NUM_THREADS) +test_tma_mbarrier_kernel( + int* __restrict__ mismatches, + bf16_t* __restrict__ verify_buf, // copy of what TMA loaded + CUtensorMap* __restrict__ tma_k, + int s_k +) { + const int tid = threadIdx.x; + const int lane = tid % 32; + + extern __shared__ __align__(128) char sbuf[]; + bf16_t* sTmaBuf = (bf16_t*)(sbuf); + uint64_t* sMbar = (uint64_t*)(sbuf + ((TILE_BYTES + 127) & ~127)); + + if (tid == 0) { + tma_mbarrier_init((uint32_t)__cvta_generic_to_shared(sMbar), 1); + asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory"); + } + __syncthreads(); + + const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); + int phase = 0; + + // Issue TMA: only thread 0 + if (tid == 0) { + tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_k, mbar_addr, 0, 0); + tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_BYTES); + } + // ALL threads wait + tma_mbarrier_wait(mbar_addr, phase); + __syncthreads(); + + // Copy to verify buffer and check + int bad = 0; + for (int i = tid; i < TILE_ROWS * TILE_COLS; i += NUM_THREADS) { + verify_buf[i] = sTmaBuf[i]; + } +} + +// Approach 2: commit_group / wait_group (no mbarrier) +__global__ void __launch_bounds__(NUM_THREADS) +test_tma_commit_group_kernel( + int* __restrict__ mismatches, + bf16_t* __restrict__ verify_buf, + CUtensorMap* __restrict__ tma_k, + int s_k +) { + const int tid = threadIdx.x; + const int lane = tid % 32; + + extern __shared__ __align__(128) char sbuf[]; + bf16_t* sTmaBuf = (bf16_t*)(sbuf); + + // Issue TMA: only thread 0 + if (tid == 0) { + // TMA with mbarrier::complete_tx — we still need a mbarrier for the TMA instruction + // But we can use commit_group/wait_group for the synchronization + // Actually, cp.async.bulk.tensor REQUIRES a mbarrier. commit_group/wait_group + // is for cp.async.bulk (non-tensor). Let's use a different approach: + // Use mbarrier just for the TMA completion signal, but wait via a simpler pattern. + } + __syncthreads(); + + // Actually, the PTX spec says cp.async.bulk.tensor.2d REQUIRES mbarrier. + // commit_group/wait_group is for cp.async.bulk (non-tensor) only. + // So the mbarrier approach is the ONLY way for TMA. + // The bug must be in how we USE the mbarrier, not in the approach itself. + // Let's instead try: only 1 thread issues TMA AND waits, then signals others. +} + +// Approach 3: Single-thread TMA issue + wait, then signal +__global__ void __launch_bounds__(NUM_THREADS) +test_tma_single_wait_kernel( + int* __restrict__ mismatches, + bf16_t* __restrict__ verify_buf, + CUtensorMap* __restrict__ tma_k, + int s_k +) { + const int tid = threadIdx.x; + const int lane = tid % 32; + + extern __shared__ __align__(128) char sbuf[]; + bf16_t* sTmaBuf = (bf16_t*)(sbuf); + uint64_t* sMbar = (uint64_t*)(sbuf + ((TILE_BYTES + 127) & ~127)); + volatile int* sFlag = (volatile int*)(sbuf + ((TILE_BYTES + 127) & ~127) + 16); + + if (tid == 0) { + tma_mbarrier_init((uint32_t)__cvta_generic_to_shared(sMbar), 1); + asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory"); + *sFlag = 0; + } + __syncthreads(); + + // ONLY thread 0 issues TMA AND waits for completion + if (tid == 0) { + const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); + int phase = 0; + tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_k, mbar_addr, 0, 0); + tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_BYTES); + tma_mbarrier_wait(mbar_addr, phase); + // Signal completion to other threads + asm volatile("fence.sc.gpu;" ::: "memory"); + *sFlag = 1; + } + + // Other threads spin-wait + if (tid != 0) { + while (*sFlag == 0) {} + } + __syncthreads(); + + // Copy to verify buffer + for (int i = tid; i < TILE_ROWS * TILE_COLS; i += NUM_THREADS) { + verify_buf[i] = sTmaBuf[i]; + } +} + +inline bool create_tma_desc_2d_bf16( + CUtensorMap* out, const void* gmem_ptr, + uint64_t rows, uint64_t cols, + uint32_t tile_rows, uint32_t tile_cols +) { + uint64_t global_dim[] = {cols, rows}; + uint64_t global_str[] = {cols * 2}; + uint32_t tile_dim[] = {tile_cols, tile_rows}; + uint32_t tile_str[] = {1, 1}; + CUresult res = cuTensorMapEncodeTiled( + out, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2, + const_cast(gmem_ptr), + global_dim, global_str, tile_dim, tile_str, + CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, + CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE + ); + if (res != CUDA_SUCCESS) { + fprintf(stderr, "cuTensorMapEncodeTiled failed: %d\n", (int)res); + return false; + } + // Driver workaround for CUDA <= 13.1 + small tensors + int dv = 0; cudaDriverGetVersion(&dv); + size_t total = rows * cols * 2; + if (dv <= 13010 && total < 131072) { + reinterpret_cast(out)[1] &= ~(1ULL << 21); + } + return true; +} + +int main() { + printf("=== Minimal TMA diagnostic (HD=%d, NUM_THREADS=%d) ===\n", HD, NUM_THREADS); + + // Create K data: (SK, HD) = (128, 16) for the first sub-tile + const int total = SK * HD; + bf16_t* h_k = (bf16_t*)calloc(total, sizeof(bf16_t)); + srand(42); + for (int i = 0; i < total; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + + // We'll load just the first (128, 16) sub-tile of K + bf16_t* h_k_sub = (bf16_t*)calloc(TILE_ROWS * TILE_COLS, sizeof(bf16_t)); + for (int r = 0; r < TILE_ROWS && r < SK; r++) + for (int c = 0; c < TILE_COLS && c < HD; c++) + h_k_sub[r * TILE_COLS + c] = h_k[r * HD + c]; + + bf16_t *d_k, *d_k_sub, *d_verify; + cudaMalloc(&d_k, total * sizeof(bf16_t)); + cudaMalloc(&d_k_sub, TILE_ROWS * TILE_COLS * sizeof(bf16_t)); + cudaMalloc(&d_verify, TILE_ROWS * TILE_COLS * sizeof(bf16_t)); + cudaMemcpy(d_k, h_k, total * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_k_sub, h_k_sub, TILE_ROWS * TILE_COLS * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemset(d_verify, 0, TILE_ROWS * TILE_COLS * sizeof(bf16_t)); + + // TMA descriptor for the sub-tile + CUtensorMap tma_k; CUtensorMap* d_tma_k; + if (!create_tma_desc_2d_bf16(&tma_k, d_k_sub, TILE_ROWS, TILE_COLS, TILE_ROWS, TILE_COLS)) { + printf("TMA desc FAILED\n"); return 1; + } + cudaMalloc(&d_tma_k, sizeof(CUtensorMap)); + cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + + int* d_mismatch; + cudaMalloc(&d_mismatch, sizeof(int)); + cudaMemset(d_mismatch, 0, sizeof(int)); + + // --- Test 1: mbarrier approach --- + printf("\n--- Test 1: mbarrier (all threads wait) ---\n"); + { + size_t smem = TILE_BYTES + 128 + 16 + 4; // TMA buf + mbar + flag + padding + cudaMemset(d_verify, 0, TILE_ROWS * TILE_COLS * sizeof(bf16_t)); + test_tma_mbarrier_kernel<<<1, NUM_THREADS, smem>>>(d_mismatch, d_verify, d_tma_k, SK); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + printf(" CUDA ERROR: %s\n", cudaGetErrorString(err)); + } else { + bf16_t* h_verify = (bf16_t*)malloc(TILE_ROWS * TILE_COLS * sizeof(bf16_t)); + cudaMemcpy(h_verify, d_verify, TILE_ROWS * TILE_COLS * sizeof(bf16_t), cudaMemcpyDeviceToHost); + int bad = 0; + for (int i = 0; i < TILE_ROWS * TILE_COLS; i++) { + if (h_verify[i] != h_k_sub[i]) bad++; + } + printf(" Mismatches: %d / %d\n", bad, TILE_ROWS * TILE_COLS); + if (bad > 0 && bad < 10) { + for (int i = 0; i < TILE_ROWS * TILE_COLS; i++) { + if (h_verify[i] != h_k_sub[i]) + printf(" [%d]: expected %u got %u\n", i, (unsigned)h_k_sub[i], (unsigned)h_verify[i]); + } + } + free(h_verify); + } + } + + // --- Test 3: single-thread wait --- + printf("\n--- Test 3: single-thread TMA issue+wait, flag signal ---\n"); + { + size_t smem = TILE_BYTES + 128 + 16 + 4; + cudaMemset(d_verify, 0, TILE_ROWS * TILE_COLS * sizeof(bf16_t)); + test_tma_single_wait_kernel<<<1, NUM_THREADS, smem>>>(d_mismatch, d_verify, d_tma_k, SK); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + printf(" CUDA ERROR: %s\n", cudaGetErrorString(err)); + } else { + bf16_t* h_verify = (bf16_t*)malloc(TILE_ROWS * TILE_COLS * sizeof(bf16_t)); + cudaMemcpy(h_verify, d_verify, TILE_ROWS * TILE_COLS * sizeof(bf16_t), cudaMemcpyDeviceToHost); + int bad = 0; + for (int i = 0; i < TILE_ROWS * TILE_COLS; i++) { + if (h_verify[i] != h_k_sub[i]) bad++; + } + printf(" Mismatches: %d / %d\n", bad, TILE_ROWS * TILE_COLS); + free(h_verify); + } + } + + cudaFree(d_k); cudaFree(d_k_sub); cudaFree(d_verify); cudaFree(d_tma_k); cudaFree(d_mismatch); + free(h_k); free(h_k_sub); + printf("\nDone.\n"); + return 0; +}