// nvfp4-megamoe-kernel/tests/unit/test_tma_minimal.cu (319 lines, 11 KiB)

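/**
 * MINIMAL TMA diagnostic — no MMA, no TMEM.
 * Just: TMA-load one K sub-tile into SMEM, copy SMEM back to a GMEM verify buffer,
 * and compare it against the source tile on the host.
 *
 * Run with 32, 64, 128 and 192 threads to isolate the multi-warp TMA bug:
 * build with -DNUM_THREADS=128 (etc.). -DHD_VAL=... likewise overrides the head dim (default 16).
 */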
#include <cuda_runtime.h>
#include <cuda.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
// Build-time knobs (override with -DHD_VAL=... / -DNUM_THREADS=...).
#if !defined(HD_VAL)
#define HD_VAL 16
#endif
#if !defined(NUM_THREADS)
#define NUM_THREADS 192
#endif

// Raw bf16 bit pattern (host and device share this 16-bit storage type).
typedef unsigned short bf16_t;

// float -> bf16 by truncation: keep the upper 16 bits of the IEEE-754 word (no rounding).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits = 0;
    std::memcpy(&bits, &f, sizeof bits);
    return static_cast<uint16_t>(bits >> 16);
}

// bf16 -> float: the bf16 bits become the upper half of the float word.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float out;
    std::memcpy(&out, &bits, sizeof out);
    return out;
}

constexpr int HD = HD_VAL;
constexpr int SK = 128;
constexpr int TILE_ROWS = 128;
constexpr int TILE_COLS = 16;                            // MMA K-tile
constexpr int TILE_BYTES = 2 * TILE_ROWS * TILE_COLS;    // 4096 (2 bytes per bf16)
// TMA helpers
__device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar, uint32_t count) {
asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(smem_mbar), "r"(count));
}
__device__ __forceinline__ void tma_mbarrier_arrive_expect_tx(uint32_t smem_mbar, uint32_t tx_bytes) {
asm volatile("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 _, [%0], %1;"
:: "r"(smem_mbar), "r"(tx_bytes) : "memory");
}
__device__ __forceinline__ void tma_load_2d(
uint32_t smem_dst, uint64_t tma_desc, uint32_t smem_mbar,
int coord_x, int coord_y
) {
asm volatile(
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes "
"[%0], [%1, {%3, %4}], [%2];"
:: "r"(smem_dst), "l"(tma_desc), "r"(smem_mbar), "r"(coord_x), "r"(coord_y)
: "memory"
);
}
__device__ __forceinline__ void tma_mbarrier_wait(uint32_t smem_mbar, int phase) {
asm volatile(
"{\n\t"
".reg .pred P1;\n\t"
"LAB_WAIT:"
"mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\n\t"
"@P1 bra.uni DONE;\n\t"
"bra.uni LAB_WAIT;\n\t"
"DONE:\n\t"
"}"
:: "r"(smem_mbar), "r"(phase), "r"(0x989680)
: "memory"
);
}
// Also test commit_group / wait_group approach
__device__ __forceinline__ void cp_async_bulk_commit_group() {
asm volatile("cp.async.bulk.commit_group;" ::: "memory");
}
template<int N>
__device__ __forceinline__ void cp_async_bulk_wait_group() {
asm volatile("cp.async.bulk.wait_group %0;" :: "n"(N) : "memory");
}
// Approach 1: mbarrier
__global__ void __launch_bounds__(NUM_THREADS)
test_tma_mbarrier_kernel(
int* __restrict__ mismatches,
bf16_t* __restrict__ verify_buf, // copy of what TMA loaded
CUtensorMap* __restrict__ tma_k,
int s_k
) {
const int tid = threadIdx.x;
const int lane = tid % 32;
extern __shared__ __align__(128) char sbuf[];
bf16_t* sTmaBuf = (bf16_t*)(sbuf);
uint64_t* sMbar = (uint64_t*)(sbuf + ((TILE_BYTES + 127) & ~127));
if (tid == 0) {
tma_mbarrier_init((uint32_t)__cvta_generic_to_shared(sMbar), 1);
asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
}
__syncthreads();
const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
int phase = 0;
// Issue TMA: only thread 0
if (tid == 0) {
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_k, mbar_addr, 0, 0);
tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_BYTES);
}
// ALL threads wait
tma_mbarrier_wait(mbar_addr, phase);
__syncthreads();
// Copy to verify buffer and check
int bad = 0;
for (int i = tid; i < TILE_ROWS * TILE_COLS; i += NUM_THREADS) {
verify_buf[i] = sTmaBuf[i];
}
}
// Approach 2: commit_group / wait_group (no mbarrier) — placeholder, not launched.
//
// The idea was to synchronize the TMA with cp.async.bulk.commit_group / wait_group
// instead of an mbarrier. This does not apply to loads. For the global -> shared
// direction, cp.async.bulk(.tensor) completes only through mbarrier::complete_tx.
// The bulk_group (commit/wait_group) completion mechanism is used for the
// shared -> global direction.
// So for TMA loads the mbarrier is the only option, and a bug must lie in HOW the
// mbarrier is used, not in the approach. (Hence approach 3: a single thread issues
// the load and waits, then signals the others.)
//
// The body issues no copy and writes nothing; it is only a block-wide barrier.
// The parameters are unused and kept so all test kernels share one signature.
__global__ void __launch_bounds__(NUM_THREADS)
test_tma_commit_group_kernel(
    int* __restrict__ mismatches,
    bf16_t* __restrict__ verify_buf,
    CUtensorMap* __restrict__ tma_k,
    int s_k
) {
    __syncthreads();
}
// Approach 3: single-thread TMA issue + wait, then signal.
// Thread 0 initializes and arms the mbarrier, issues the TMA load of the
// (TILE_ROWS x TILE_COLS) tile at (0, 0), and is the ONLY thread that waits on the
// barrier. It then publishes completion through a volatile shared-memory flag that the
// remaining threads spin on. After a block barrier, every thread copies the SMEM tile
// to `verify_buf` for host-side comparison.
//
// Launch: 1 block, blockDim.x <= NUM_THREADS. The spin-wait relies on independent
// thread scheduling (lanes 1..31 of warp 0 spin while lane 0 progresses).
// NOTE(review): this holds on the Volta+ targets that TMA requires — confirm if retargeted.
// Dynamic smem: [TMA tile, padded to 128 B][mbarrier 8 B][8 B pad][int flag].
// `mismatches` and `s_k` are unused; they are kept so all test kernels share one signature.
__global__ void __launch_bounds__(NUM_THREADS)
test_tma_single_wait_kernel(
    int* __restrict__ mismatches,
    bf16_t* __restrict__ verify_buf,
    CUtensorMap* __restrict__ tma_k,
    int s_k
) {
    (void)mismatches;
    (void)s_k;
    const int tid = threadIdx.x;
    extern __shared__ __align__(128) char sbuf[];
    constexpr int kMbarOffset = (TILE_BYTES + 127) & ~127;
    bf16_t* sTmaBuf = (bf16_t*)(sbuf);
    uint64_t* sMbar = (uint64_t*)(sbuf + kMbarOffset);
    volatile int* sFlag = (volatile int*)(sbuf + kMbarOffset + 16);
    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);

    if (tid == 0) {
        tma_mbarrier_init(mbar_addr, 1);
        // Make the initialized barrier visible to the async (TMA) proxy.
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
        *sFlag = 0;
    }
    __syncthreads();

    if (tid == 0) {
        // ONLY thread 0 issues the TMA AND waits for its completion.
        const int phase = 0;
        // Canonical order: arm the barrier before issuing the TMA that completes it.
        tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_BYTES);
        tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_k, mbar_addr, 0, 0);
        tma_mbarrier_wait(mbar_addr, phase);
        // Order the observed TMA data before the flag store, then signal the others.
        asm volatile("fence.sc.gpu;" ::: "memory");
        *sFlag = 1;
    } else {
        // Other threads spin until thread 0 signals completion.
        while (*sFlag == 0) {}
    }
    __syncthreads();

    // Copy the tile out for host-side verification; stride by the actual block size.
    for (int i = tid; i < TILE_ROWS * TILE_COLS; i += (int)blockDim.x) {
        verify_buf[i] = sTmaBuf[i];
    }
}
// Encodes a 2-D, row-major, dense bf16 TMA descriptor into `*out`.
//   gmem_ptr            : device base address of a (rows x cols) bf16 matrix
//   tile_rows/tile_cols : box (tile) size loaded by one TMA instruction
// No interleave, no swizzle, no L2 promotion, no OOB fill. Returns false (after printing
// the driver error code) if cuTensorMapEncodeTiled rejects the parameters.
// Per the driver API, the row pitch (cols * 2 bytes) must be a multiple of 16.
inline bool create_tma_desc_2d_bf16(
    CUtensorMap* out, const void* gmem_ptr,
    uint64_t rows, uint64_t cols,
    uint32_t tile_rows, uint32_t tile_cols
) {
    constexpr uint64_t kBytesPerElem = 2;  // bf16
    // TMA dimension 0 is the innermost (contiguous) one, so columns come first.
    const uint64_t dims[2] = {cols, rows};
    // Byte stride of dimension 1 (rows); dimension 0's stride is implicit.
    const uint64_t row_pitch[1] = {cols * kBytesPerElem};
    const uint32_t box[2] = {tile_cols, tile_rows};
    const uint32_t element_strides[2] = {1, 1};

    const CUresult status = cuTensorMapEncodeTiled(
        out, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2,
        const_cast<void*>(gmem_ptr),
        dims, row_pitch, box, element_strides,
        CU_TENSOR_MAP_INTERLEAVE_NONE,
        CU_TENSOR_MAP_SWIZZLE_NONE,
        CU_TENSOR_MAP_L2_PROMOTION_NONE,
        CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
    );
    if (status != CUDA_SUCCESS) {
        fprintf(stderr, "cuTensorMapEncodeTiled failed: %d\n", (int)status);
        return false;
    }

    // Driver workaround for CUDA <= 13.1 + small tensors.
    // NOTE(review): this clears bit 21 of the second 64-bit word of the opaque CUtensorMap.
    // It applies when the reported driver version is <= 13010 and the tensor is smaller
    // than 128 KiB. The field's meaning is undocumented; confirm it against the driver
    // issue being worked around.
    // If cudaDriverGetVersion fails, the version stays 0 and the workaround is applied.
    int driver_version = 0;
    cudaDriverGetVersion(&driver_version);
    const size_t tensor_bytes = rows * cols * kBytesPerElem;
    const bool apply_small_tensor_fix = (driver_version <= 13010) && (tensor_bytes < 131072);
    if (apply_small_tensor_fix) {
        uint64_t* raw_words = reinterpret_cast<uint64_t*>(out);
        raw_words[1] &= ~(1ULL << 21);
    }
    return true;
}
// Reports a failed CUDA runtime call on stderr. Returns true iff `err` is cudaSuccess.
static bool tma_test_cuda_ok(cudaError_t err, const char* what) {
    if (err == cudaSuccess) return true;
    fprintf(stderr, "%s failed: %s\n", what, cudaGetErrorString(err));
    return false;
}

// Collects the status of the most recently launched kernel, then compares `d_verify`
// with `expected`, element by element over TILE_ROWS * TILE_COLS elements.
// Prints the mismatch count, and each mismatch when there are fewer than 10.
// Returns true only if the kernel ran without error and every element matches.
static bool check_tma_tile_result(const bf16_t* d_verify, const bf16_t* expected) {
    const int n = TILE_ROWS * TILE_COLS;
    // Launch-configuration errors surface via cudaGetLastError; faults via the sync.
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess) err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        printf(" CUDA ERROR: %s\n", cudaGetErrorString(err));
        return false;
    }
    bf16_t* h_verify = (bf16_t*)malloc(n * sizeof(bf16_t));
    if (h_verify == nullptr) {
        printf(" host allocation failed\n");
        return false;
    }
    err = cudaMemcpy(h_verify, d_verify, n * sizeof(bf16_t), cudaMemcpyDeviceToHost);
    if (err != cudaSuccess) {
        printf(" CUDA ERROR: %s\n", cudaGetErrorString(err));
        free(h_verify);
        return false;
    }
    int bad = 0;
    for (int i = 0; i < n; i++) {
        if (h_verify[i] != expected[i]) bad++;
    }
    printf(" Mismatches: %d / %d\n", bad, n);
    if (bad > 0 && bad < 10) {
        for (int i = 0; i < n; i++) {
            if (h_verify[i] != expected[i])
                printf(" [%d]: expected %u got %u\n", i, (unsigned)expected[i], (unsigned)h_verify[i]);
        }
    }
    free(h_verify);
    return bad == 0;
}

// Runs the TMA diagnostics (Test 1: all threads wait on the mbarrier; Test 3: one thread
// waits and signals via a flag). Exits with EXIT_FAILURE if setup fails, a kernel
// errors, or any loaded element differs from the source tile, so the harness can fail.
int main() {
    printf("=== Minimal TMA diagnostic (HD=%d, NUM_THREADS=%d) ===\n", HD, NUM_THREADS);
    const int total = SK * HD;
    const int tile_elems = TILE_ROWS * TILE_COLS;
    const size_t tile_bytes = (size_t)tile_elems * sizeof(bf16_t);
    // Dynamic smem per kernel: TMA buf + mbar + flag + padding.
    const size_t smem = TILE_BYTES + 128 + 16 + 4;

    bf16_t* h_k = (bf16_t*)calloc(total, sizeof(bf16_t));
    bf16_t* h_k_sub = (bf16_t*)calloc(tile_elems, sizeof(bf16_t));
    bf16_t *d_k = nullptr, *d_k_sub = nullptr, *d_verify = nullptr;
    CUtensorMap* d_tma_k = nullptr;
    int* d_mismatch = nullptr;

    bool setup_ok = (h_k != nullptr && h_k_sub != nullptr);
    if (!setup_ok) fprintf(stderr, "host allocation failed\n");

    if (setup_ok) {
        // Create K data: (SK, HD); only the first (TILE_ROWS, TILE_COLS) sub-tile is loaded.
        srand(42);
        for (int i = 0; i < total; i++) h_k[i] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);
        for (int r = 0; r < TILE_ROWS && r < SK; r++)
            for (int c = 0; c < TILE_COLS && c < HD; c++)
                h_k_sub[r * TILE_COLS + c] = h_k[r * HD + c];
    }

    setup_ok = setup_ok
        && tma_test_cuda_ok(cudaMalloc(&d_k, total * sizeof(bf16_t)), "cudaMalloc(d_k)")
        && tma_test_cuda_ok(cudaMalloc(&d_k_sub, tile_bytes), "cudaMalloc(d_k_sub)")
        && tma_test_cuda_ok(cudaMalloc(&d_verify, tile_bytes), "cudaMalloc(d_verify)")
        && tma_test_cuda_ok(cudaMemcpy(d_k, h_k, total * sizeof(bf16_t), cudaMemcpyHostToDevice), "cudaMemcpy(d_k)")
        && tma_test_cuda_ok(cudaMemcpy(d_k_sub, h_k_sub, tile_bytes, cudaMemcpyHostToDevice), "cudaMemcpy(d_k_sub)")
        && tma_test_cuda_ok(cudaMemset(d_verify, 0, tile_bytes), "cudaMemset(d_verify)");

    // TMA descriptor for the sub-tile (box == whole tensor).
    CUtensorMap tma_k{};
    if (setup_ok && !create_tma_desc_2d_bf16(&tma_k, d_k_sub, TILE_ROWS, TILE_COLS, TILE_ROWS, TILE_COLS)) {
        printf("TMA desc FAILED\n");
        setup_ok = false;
    }
    setup_ok = setup_ok
        && tma_test_cuda_ok(cudaMalloc(&d_tma_k, sizeof(CUtensorMap)), "cudaMalloc(d_tma_k)")
        && tma_test_cuda_ok(cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice), "cudaMemcpy(d_tma_k)")
        && tma_test_cuda_ok(cudaMalloc(&d_mismatch, sizeof(int)), "cudaMalloc(d_mismatch)")
        && tma_test_cuda_ok(cudaMemset(d_mismatch, 0, sizeof(int)), "cudaMemset(d_mismatch)");

    int failed_tests = 0;
    if (setup_ok) {
        // --- Test 1: mbarrier approach ---
        printf("\n--- Test 1: mbarrier (all threads wait) ---\n");
        if (tma_test_cuda_ok(cudaMemset(d_verify, 0, tile_bytes), "cudaMemset(d_verify)")) {
            test_tma_mbarrier_kernel<<<1, NUM_THREADS, smem>>>(d_mismatch, d_verify, d_tma_k, SK);
            if (!check_tma_tile_result(d_verify, h_k_sub)) failed_tests++;
        } else {
            failed_tests++;
        }

        // --- Test 3: single-thread wait ---
        printf("\n--- Test 3: single-thread TMA issue+wait, flag signal ---\n");
        if (tma_test_cuda_ok(cudaMemset(d_verify, 0, tile_bytes), "cudaMemset(d_verify)")) {
            test_tma_single_wait_kernel<<<1, NUM_THREADS, smem>>>(d_mismatch, d_verify, d_tma_k, SK);
            if (!check_tma_tile_result(d_verify, h_k_sub)) failed_tests++;
        } else {
            failed_tests++;
        }
    }

    // cudaFree(nullptr) / free(nullptr) are no-ops, so partial setups are cleaned up too.
    cudaFree(d_k); cudaFree(d_k_sub); cudaFree(d_verify); cudaFree(d_tma_k); cudaFree(d_mismatch);
    free(h_k); free(h_k_sub);
    printf("\nDone.\n");
    return (setup_ok && failed_tests == 0) ? EXIT_SUCCESS : EXIT_FAILURE;
}