Files
nvfp4-megamoe-kernel/tests/unit/test_tma_subtile.cu

194 lines
6.8 KiB
Plaintext

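/**
* TMA sub-tile load test: replicates the exact TMA pattern used in test_qk_softmax.
* A TMA descriptor covers the full (SK, HD) K matrix, and each (128, 16) sub-tile is
* loaded at (coord_x = kt*16, coord_y = 0).
* The loaded sub-tiles are checked against the host copy of K; test_fmha_6warp, which
* loads K directly from GMEM, is the known-working reference for this data path.
*/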
#include <cuda_runtime.h>
#include <cuda.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
// Compile-time test configuration; override with -DHD_VAL=... / -DNUM_THREADS=... .
#if !defined(HD_VAL)
#  define HD_VAL 64
#endif
#if !defined(NUM_THREADS)
#  define NUM_THREADS 128
#endif
// BF16 values are handled as raw 16-bit patterns throughout this test.
typedef unsigned short bf16_t;

/// FP32 -> BF16 by truncation: keeps the upper 16 bits of the IEEE-754 encoding
/// (no round-to-nearest-even).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    const uint32_t upper = bits >> 16;
    return (uint16_t)upper;
}

/// BF16 -> FP32: places the 16 BF16 bits in the high half of an FP32 word (exact).
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float out;
    memcpy(&out, &bits, sizeof out);
    return out;
}
constexpr int HD = HD_VAL;                  // number of K columns (TMA dimension 0)
constexpr int SK = 128;                     // number of K rows
constexpr int MMA_K_BF16 = 16;              // columns per TMA sub-tile
constexpr int TILE_SZ = 128 * MMA_K_BF16;   // 2048 BF16 per (128 x 16) sub-tile

// The test loads HD / MMA_K_BF16 sub-tiles; that division truncates, so a non-multiple HD
// would silently leave trailing columns untested. A multiple of 16 BF16 also keeps the TMA
// row pitch (HD * 2 bytes) a multiple of 16 bytes, as cuTensorMapEncodeTiled requires.
static_assert(HD > 0 && HD % MMA_K_BF16 == 0,
              "HD_VAL must be a positive multiple of MMA_K_BF16 (16)");
// Each sub-tile (and the host-side check) spans 128 rows of K.
static_assert(SK >= 128, "SK must cover the 128-row TMA box");
/// Initialise the 64-bit mbarrier at CTA-shared address `smem_mbar` so that each phase
/// expects `count` arrivals (PTX `mbarrier.init`).
__device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar, uint32_t count) {
    asm volatile(
        "mbarrier.init.shared::cta.b64 [%0], %1;"
        :
        : "r"(smem_mbar), "r"(count));
}
/// Arrive once on the mbarrier at `smem_mbar` (release, CTA scope) and register
/// `tx_bytes` of expected asynchronous transfer for the current phase. The arrival
/// token is discarded (`_` sink).
__device__ __forceinline__ void tma_mbarrier_arrive_expect_tx(uint32_t smem_mbar, uint32_t tx_bytes) {
    asm volatile(
        "mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 "
        "_, [%0], %1;"
        :
        : "r"(smem_mbar), "r"(tx_bytes)
        : "memory");
}
/// Issue a 2-D TMA (cp.async.bulk.tensor, SM90+) copy of one descriptor box into shared
/// memory at `smem_dst`. Completion is signalled through complete_tx on the mbarrier at
/// `smem_mbar`.
///   tma_desc : generic address of the CUtensorMap (the caller passes a device pointer)
///   coord_x  : element coordinate along dimension 0 (the contiguous/column axis)
///   coord_y  : element coordinate along dimension 1 (rows)
__device__ __forceinline__ void tma_load_2d(
    uint32_t smem_dst, uint64_t tma_desc, uint32_t smem_mbar,
    int coord_x, int coord_y
) {
    asm volatile(
        "cp.async.bulk.tensor.2d.shared::cluster.global"
        ".mbarrier::complete_tx::bytes [%0], [%1, {%3, %4}], [%2];"
        :
        : "r"(smem_dst), "l"(tma_desc), "r"(smem_mbar), "r"(coord_x), "r"(coord_y)
        : "memory");
}
/**
 * Spin until the mbarrier at CTA-shared address `smem_mbar` has completed the phase
 * whose parity is `phase` (the kernel below passes 0/1 and toggles it after each wait).
 *
 * Loops on mbarrier.try_wait.parity with .acquire.cta ordering; 0x989680 (10,000,000)
 * is passed as try_wait's suspend-time hint. There is no timeout: if the phase never
 * completes (e.g. the expected TMA bytes never arrive), this hangs.
 * NOTE(review): the LAB_WAIT/DONE labels live inside a `{ }` PTX scope, presumably so
 * that several inlined copies of this helper do not produce duplicate labels.
 */
__device__ __forceinline__ void tma_mbarrier_wait(uint32_t smem_mbar, int phase) {
asm volatile(
"{\n\t"
".reg .pred P1;\n\t"
// Label and try_wait end up on the same PTX line (no "\n\t" after the label).
"LAB_WAIT:"
"mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\n\t"
// P1 set => phase complete, leave the loop; otherwise retry.
"@P1 bra.uni DONE;\n\t"
"bra.uni LAB_WAIT;\n\t"
"DONE:\n\t"
"}"
:: "r"(smem_mbar), "r"(phase), "r"(0x989680)
: "memory"
);
}
/**
 * Test: TMA load a sub-tile from a (SK, HD) K matrix.
 * This is the EXACT pattern from test_qk_softmax.cu.
 *
 * For each kt in [0, n_kt):
 *   - thread 0 issues a 2-D TMA load of the descriptor's box at
 *     (coord_x = kt * MMA_K_BF16, coord_y = 0) into shared memory;
 *   - every thread waits on the mbarrier phase;
 *   - the block copies the TILE_SZ loaded elements to
 *     verify_buf[kt * TILE_SZ, (kt + 1) * TILE_SZ).
 *
 * Launch: main() uses one block of NUM_THREADS threads; __launch_bounds__ caps the
 * block at NUM_THREADS. Dynamic shared memory must hold TILE_SZ bf16 values (rounded up
 * to 16 bytes) followed by one 8-byte mbarrier.
 *
 * @param verify_buf  output; must hold n_kt * TILE_SZ elements.
 * @param tma_k       device pointer to the encoded CUtensorMap; its generic address is
 *                    handed directly to cp.async.bulk.tensor.
 * @param s_k         unused here (NOTE(review): presumably kept for signature parity
 *                    with test_qk_softmax).
 * @param n_kt        number of MMA_K_BF16-wide column sub-tiles to load.
 */
__global__ void __launch_bounds__(NUM_THREADS)
test_tma_subtile_kernel(
    bf16_t* __restrict__ verify_buf,  // what we loaded via TMA
    CUtensorMap* __restrict__ tma_k,
    int s_k, int n_kt
) {
    (void)s_k;
    const int tid = threadIdx.x;
    const int wid = tid / 32;
    const int lane = tid % 32;

    // Shared-memory layout: [TMA tile: TILE_SZ bf16] [pad to 16 B] [mbarrier: 8 B].
    extern __shared__ __align__(128) char sbuf[];
    size_t off = 0;
    bf16_t* sTmaBuf = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
    off = (off + 15) & ~(size_t)15;
    uint64_t* sMbar = (uint64_t*)(sbuf + off); off += 8;
    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);

    if (tid == 0) {
        // One arrival per phase: thread 0's arrive.expect_tx below.
        tma_mbarrier_init(mbar_addr, 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    __syncthreads();

    int phase = 0;
    for (int kt = 0; kt < n_kt; kt++) {
        // Load K sub-tile via TMA — same as test_qk_softmax. The copy is issued before
        // expect_tx; PTX allows the mbarrier tx-count to go transiently negative, and the
        // phase cannot complete before the single arrival, so this order is legal.
        if (wid == 0 && lane == 0) {
            tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_k,
                        mbar_addr, kt * MMA_K_BF16, 0);
            tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_SZ * sizeof(bf16_t));
        }
        tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
        __syncthreads();

        // Copy sub-tile to verify buffer. Stride by the real block size (not NUM_THREADS)
        // so every element is copied even if the kernel is launched with fewer threads.
        const int base = kt * TILE_SZ;
        for (int i = tid; i < TILE_SZ; i += (int)blockDim.x) {
            verify_buf[base + i] = sTmaBuf[i];
        }
        // The next iteration's TMA overwrites sTmaBuf; all reads must finish first.
        __syncthreads();
    }
}
/**
 * Encode a 2-D BF16 tensor map for a row-major `rows x cols` matrix at `gmem_ptr`, using a
 * `tile_rows x tile_cols` box. Options: INTERLEAVE_NONE, SWIZZLE_NONE, L2_PROMOTION_NONE,
 * FLOAT_OOB_FILL_NONE. Returns false, after printing the driver's numeric error code to
 * stderr, when cuTensorMapEncodeTiled rejects the parameters.
 *
 * NOTE(review): after encoding, bit 21 of the descriptor's second 64-bit word is cleared
 * when cudaDriverGetVersion reports <= 13010 and the tensor is under 128 KiB. This pokes
 * at the opaque CUtensorMap layout. It is presumably a driver-bug workaround; confirm the
 * rationale before relying on or removing it. The version query's return code is not
 * checked: if it fails, the version stays 0 and the patch is applied.
 */
inline bool create_tma_desc_2d_bf16(
    CUtensorMap* out, const void* gmem_ptr,
    uint64_t rows, uint64_t cols,
    uint32_t tile_rows, uint32_t tile_cols
) {
    constexpr uint64_t kElemBytes = 2;  // sizeof(BF16)

    // Dimension 0 is the contiguous (column) axis; one stride (bytes per row) for rank 2.
    uint64_t dims[2]        = {cols, rows};
    uint64_t row_pitch[1]   = {cols * kElemBytes};
    uint32_t box[2]         = {tile_cols, tile_rows};
    uint32_t elem_stride[2] = {1, 1};

    const CUresult status = cuTensorMapEncodeTiled(
        out, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2, const_cast<void*>(gmem_ptr),
        dims, row_pitch, box, elem_stride,
        CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE,
        CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    if (status != CUDA_SUCCESS) {
        fprintf(stderr, "cuTensorMapEncodeTiled failed: %d\n", (int)status);
        return false;
    }

    int driver_version = 0;
    cudaDriverGetVersion(&driver_version);
    const size_t tensor_bytes = rows * cols * kElemBytes;
    const bool small_tensor_on_old_driver =
        driver_version <= 13010 && tensor_bytes < 131072;
    if (small_tensor_on_old_driver) {
        uint64_t* words = reinterpret_cast<uint64_t*>(out);
        words[1] &= ~(1ULL << 21);
    }
    return true;
}
/**
 * Host driver for the TMA sub-tile test.
 *
 * Steps:
 *   1. Build a pseudo-random (SK, HD) BF16 K matrix (srand(42)).
 *   2. Encode a TMA descriptor with a (128, MMA_K_BF16) box.
 *   3. Launch test_tma_subtile_kernel on one block of NUM_THREADS threads.
 *   4. Check bit-for-bit that sub-tile kt equals K[0:128, kt*16 : kt*16+16],
 *      stored row-major as 128 x 16.
 *
 * Returns 0 on PASS, and 1 on any mismatch or CUDA/driver failure. Every CUDA call is
 * checked, and all resources are released on every exit path.
 */
int main() {
    printf("=== TMA sub-tile load test (HD=%d, NUM_THREADS=%d) ===\n", HD, NUM_THREADS);
    const int n_kt = HD / MMA_K_BF16;            // number of K sub-tiles
    const int tile_rows = TILE_SZ / MMA_K_BF16;  // rows per TMA box (128)
    const size_t k_bytes = (size_t)SK * HD * sizeof(bf16_t);
    const size_t verify_bytes = (size_t)n_kt * TILE_SZ * sizeof(bf16_t);

    // Everything starts null so cleanup() is safe from any exit path
    // (cudaFree(nullptr) and free(nullptr) are no-ops).
    bf16_t* h_k = nullptr;
    bf16_t* h_verify = nullptr;
    bf16_t* d_k = nullptr;
    bf16_t* d_verify = nullptr;
    CUtensorMap* d_tma_k = nullptr;
    auto cleanup = [&]() {
        cudaFree(d_k); cudaFree(d_verify); cudaFree(d_tma_k);
        free(h_k); free(h_verify);
    };

// Leave main with status 1 on any failing CUDA runtime call, releasing resources first.
#define TMA_TEST_CHECK(call)                                                        \
    do {                                                                            \
        cudaError_t check_err_ = (call);                                            \
        if (check_err_ != cudaSuccess) {                                            \
            printf("CUDA ERROR: %s (%s)\n", cudaGetErrorString(check_err_), #call); \
            cleanup();                                                              \
            return 1;                                                               \
        }                                                                           \
    } while (0)

    h_k = (bf16_t*)calloc((size_t)SK * HD, sizeof(bf16_t));
    h_verify = (bf16_t*)malloc(verify_bytes);
    if (h_k == nullptr || h_verify == nullptr) {
        printf("host allocation FAILED\n");
        cleanup();
        return 1;
    }
    srand(42);
    for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);

    TMA_TEST_CHECK(cudaMalloc(&d_k, k_bytes));
    TMA_TEST_CHECK(cudaMalloc(&d_verify, verify_bytes));
    TMA_TEST_CHECK(cudaMemcpy(d_k, h_k, k_bytes, cudaMemcpyHostToDevice));
    // Pre-fill with 0xFFFF, a BF16 NaN pattern the generator above never produces.
    // Zero-filling would let unwritten elements falsely match the 0.0 entries K contains.
    TMA_TEST_CHECK(cudaMemset(d_verify, 0xFF, verify_bytes));

    // TMA descriptor for full K matrix: (SK, HD) with tile (128, 16)
    CUtensorMap tma_k;
    if (!create_tma_desc_2d_bf16(&tma_k, d_k, SK, HD, tile_rows, MMA_K_BF16)) {
        printf("TMA desc FAILED\n");
        cleanup();
        return 1;
    }
    TMA_TEST_CHECK(cudaMalloc(&d_tma_k, sizeof(CUtensorMap)));
    TMA_TEST_CHECK(cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice));

    // Dynamic smem: the BF16 tile plus headroom for the 16-B-aligned 8-byte mbarrier.
    const size_t smem = TILE_SZ * sizeof(bf16_t) + 16 + 128;
    test_tma_subtile_kernel<<<1, NUM_THREADS, smem>>>(d_verify, d_tma_k, SK, n_kt);
    TMA_TEST_CHECK(cudaGetLastError());       // launch-configuration errors
    TMA_TEST_CHECK(cudaDeviceSynchronize());  // faults raised while the kernel ran
    TMA_TEST_CHECK(cudaMemcpy(h_verify, d_verify, verify_bytes, cudaMemcpyDeviceToHost));
#undef TMA_TEST_CHECK

    // Verify: each sub-tile should contain K[0:128, kt*16:kt*16+16]
    int total_bad = 0;
    for (int kt = 0; kt < n_kt; kt++) {
        int bad = 0;
        for (int r = 0; r < tile_rows; r++) {
            for (int c = 0; c < MMA_K_BF16; c++) {
                bf16_t expected = h_k[r * HD + kt * MMA_K_BF16 + c];
                bf16_t got = h_verify[kt * TILE_SZ + r * MMA_K_BF16 + c];
                if (got != expected) bad++;
            }
        }
        if (bad > 0) printf(" kt=%d: %d mismatches\n", kt, bad);
        total_bad += bad;
    }
    printf("Total mismatches: %d\n", total_bad);
    printf("%s\n", total_bad == 0 ? "PASSED" : "FAILED");
    cleanup();
    return total_bad == 0 ? 0 : 1;
}