/**
 * Verify TMA load correctness by comparing against a direct GMEM read.
 * Isolates TMA + mbarrier + canonical write from the rest of the FMHA pipeline.
 */

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <cuda.h>
#include <cuda_runtime.h>

#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
#include "dsv4/kernels/attention/fmha_tma.cuh"

using namespace dsv4::kernels::attention;

// Host-side FP32 -> BF16 by truncation: keeps only the upper 16 bits of the
// IEEE-754 bit pattern (no round-to-nearest). Assumes bf16_t is a raw 16-bit
// storage type (declared in fmha_common.cuh) — TODO confirm.
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    const uint32_t high_half = bits >> 16;
    return static_cast<uint16_t>(high_half);
}

// Host-side BF16 -> FP32: places the 16 BF16 bits in the upper half of an
// FP32 bit pattern with a zero low half (exact widening).
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float widened;
    memcpy(&widened, &bits, sizeof widened);
    return widened;
}

// Tile under test, measured in BF16 elements.
constexpr int ROWS{128};  // tile height
constexpr int COLS{16};   // tile width: one TMA sub-tile

/**
 * Test 1: TMA load of a (ROWS, COLS) = (128, 16) BF16 tile, compared against a direct read.
 *
 * Both paths stage the tile in sData_tma and then run the identical
 * write_smem_canonical<ROWS, COLS, 32> conversion on warp 0. A difference
 * between result_canonical and result_direct therefore isolates the bytes
 * delivered by TMA from those of a plain element-wise GMEM read.
 *
 * Launch requirements:
 *   - single block (blockIdx is not used; extra blocks would redundantly write the same outputs);
 *   - blockDim.x >= 32, since warp 0 performs the canonical conversion;
 *   - dynamic SMEM >= 2 * ROWS * COLS * sizeof(bf16_t) bytes plus alignment slack
 *     (TMA staging buffer + canonical buffer).
 *
 * @param gmem_data         ROWS * COLS BF16 elements read directly by the second path; presumably
 *                          the same buffer tma_desc was encoded over (row-major) — TODO confirm at caller.
 * @param result_canonical  output (ROWS * COLS): canonical layout produced from the TMA-loaded tile.
 * @param result_direct     output (ROWS * COLS): canonical layout produced from the direct GMEM read.
 * @param tma_desc          device pointer to the tensor map consumed by tma_load_2d.
 */
__global__ void test_tma_load_kernel(
    const bf16_t* __restrict__ gmem_data,
    bf16_t* __restrict__ result_canonical,  // output: canonical layout from TMA path
    bf16_t* __restrict__ result_direct,     // output: canonical layout from direct path
    CUtensorMap* __restrict__ tma_desc
) {
    constexpr int TILE_ELEMS = ROWS * COLS;
    constexpr int TILE_BYTES = TILE_ELEMS * (int)sizeof(bf16_t);

    extern __shared__ __align__(128) char sbuf[];
    size_t off = 0;
    bf16_t* sData_tma = (bf16_t*)(sbuf + off);  // TMA destination / direct-path staging
    off += TILE_BYTES;
    off = (off + 127) & ~(size_t)127;           // 128-byte align the second buffer
    bf16_t* sData_canonical = (bf16_t*)(sbuf + off);

    const int tid = threadIdx.x;
    const int lane = tid % 32;
    const int wid = tid / 32;
    const int nthreads = blockDim.x;  // block-wide copy stride (no longer hard-coded to 128)

    // One mbarrier with an expected arrival count of 1 (the thread that issues the TMA).
    __shared__ uint64_t sMbar[1];
    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);

    if (tid == 0) {
        tma_mbarrier_init(mbar_addr, 1);
        // Publish the barrier initialization (release, cluster scope) before it is used.
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    __syncthreads();

    // --- TMA path ---
    // NOTE(review): tma_desc lives in global memory (host cudaMemcpy). The CUDA Programming
    // Guide recommends passing a CUtensorMap as a `const __grid_constant__` parameter; confirm
    // whether this placement needs a tensormap proxy fence on the target toolchain.
    if (wid == 0 && lane == 0) {
        // Conventional order: post the expected transaction bytes (and this thread's arrival)
        // first, then issue the copy whose completion decrements the tx-count.
        tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_BYTES);
        const uint32_t smem_dst = (uint32_t)__cvta_generic_to_shared(sData_tma);
        tma_load_2d(smem_dst, (uint64_t)tma_desc, mbar_addr, 0, 0);
    }
    tma_mbarrier_wait(mbar_addr, 0);  // phase 0 completes once the TMA bytes have landed
    __syncthreads();

    // Convert to canonical layout (warp 0 only).
    if (wid == 0) write_smem_canonical<ROWS, COLS, 32>(sData_canonical, sData_tma);
    __syncthreads();

    // Write the TMA-path canonical result to GMEM.
    for (int i = tid; i < TILE_ELEMS; i += nthreads) {
        result_canonical[i] = sData_canonical[i];
    }

    // --- Direct path ---
    // Overwrite the staging buffer with a plain element-wise GMEM read, then rerun the
    // same conversion. sData_tma is not read again until after the barrier below. That
    // barrier also guarantees every thread has finished reading sData_canonical above
    // before warp 0 overwrites it.
    for (int i = tid; i < TILE_ELEMS; i += nthreads) {
        sData_tma[i] = gmem_data[i];
    }
    __syncthreads();

    if (wid == 0) write_smem_canonical<ROWS, COLS, 32>(sData_canonical, sData_tma);
    __syncthreads();

    for (int i = tid; i < TILE_ELEMS; i += nthreads) {
        result_direct[i] = sData_canonical[i];
    }
}

/**
 * Host driver for the TMA load check.
 *
 * Steps:
 *   1. Uploads a seeded pseudo-random (ROWS, COLS) BF16 tile.
 *   2. Builds a TMA descriptor for it.
 *   3. Runs test_tma_load_kernel once on one 128-thread block.
 *   4. Compares the TMA-path and direct-path canonical outputs.
 *
 * Returns 0 when both canonical outputs match exactly. Returns 1 on any mismatch,
 * and also on any host allocation, CUDA API, kernel or descriptor failure.
 * Every failure path releases all resources acquired so far.
 */
int main() {
    printf("TMA Load Verification Test\n");

    constexpr int N = ROWS * COLS;
    constexpr size_t BYTES = (size_t)N * sizeof(bf16_t);

    // All resources start null so cleanup() is safe from any exit point
    // (free(NULL) and cudaFree(nullptr) are no-ops).
    bf16_t* h_data = nullptr;
    bf16_t* h_tma = nullptr;
    bf16_t* h_direct = nullptr;
    bf16_t *d_data = nullptr, *d_result_tma = nullptr, *d_result_direct = nullptr;
    CUtensorMap* d_tma_desc = nullptr;

    auto cleanup = [&]() {
        cudaFree(d_data); cudaFree(d_result_tma); cudaFree(d_result_direct); cudaFree(d_tma_desc);
        free(h_data); free(h_tma); free(h_direct);
    };
    auto fail = [&]() { cleanup(); return 1; };
    // Reports a failed CUDA call together with the operation name; returns false on failure.
    auto ok = [](cudaError_t err, const char* what) {
        if (err == cudaSuccess) return true;
        printf("CUDA ERROR (%s): %s\n", what, cudaGetErrorString(err));
        return false;
    };

    // Host buffers
    h_data = (bf16_t*)malloc(BYTES);
    h_tma = (bf16_t*)malloc(BYTES);
    h_direct = (bf16_t*)malloc(BYTES);
    if (!h_data || !h_tma || !h_direct) {
        printf("Host allocation FAILED\n");
        return fail();
    }

    srand(42);
    for (int i = 0; i < N; i++) h_data[i] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);

    // Device buffers
    if (!ok(cudaMalloc(&d_data, BYTES), "cudaMalloc d_data") ||
        !ok(cudaMalloc(&d_result_tma, BYTES), "cudaMalloc d_result_tma") ||
        !ok(cudaMalloc(&d_result_direct, BYTES), "cudaMalloc d_result_direct") ||
        !ok(cudaMemcpy(d_data, h_data, BYTES, cudaMemcpyHostToDevice), "cudaMemcpy H2D data")) {
        return fail();
    }

    // Create TMA descriptor for (ROWS, COLS) BF16 and copy it to device memory.
    CUtensorMap tma_desc;
    if (!create_tma_desc_2d_bf16(&tma_desc, d_data, ROWS, COLS, ROWS, COLS)) {
        printf("TMA descriptor creation FAILED\n");
        return fail();
    }
    if (!ok(cudaMalloc(&d_tma_desc, sizeof(CUtensorMap)), "cudaMalloc d_tma_desc") ||
        !ok(cudaMemcpy(d_tma_desc, &tma_desc, sizeof(CUtensorMap), cudaMemcpyHostToDevice),
            "cudaMemcpy H2D tma_desc")) {
        return fail();
    }

    const int smem = (int)(2 * BYTES) + 128;  // two SMEM buffers + alignment slack
    test_tma_load_kernel<<<1, 128, smem>>>(d_data, d_result_tma, d_result_direct, d_tma_desc);
    // Launch-configuration errors are only reported by cudaGetLastError();
    // faults during execution surface at the synchronize.
    if (!ok(cudaGetLastError(), "kernel launch") ||
        !ok(cudaDeviceSynchronize(), "kernel execution")) {
        return fail();
    }

    if (!ok(cudaMemcpy(h_tma, d_result_tma, BYTES, cudaMemcpyDeviceToHost), "cudaMemcpy D2H tma") ||
        !ok(cudaMemcpy(h_direct, d_result_direct, BYTES, cudaMemcpyDeviceToHost), "cudaMemcpy D2H direct")) {
        return fail();
    }

    // Primary check: both paths share the same canonical conversion, so they must agree bit-for-bit.
    int mismatches = 0;
    for (int i = 0; i < N; i++) {
        if (h_tma[i] != h_direct[i]) mismatches++;
    }
    printf("TMA vs Direct canonical: %d mismatches out of %d\n", mismatches, N);

    // Diagnostic only (does not affect PASS/FAIL): compare the TMA canonical output against the
    // host data under the assumed mapping
    //   canonical[core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c]
    // with 8x8 core matrices (for (128, 16): CORES_MN = 16, CORES_K = 2).
    constexpr int CORES_MN = ROWS / 8;
    int canonical_mismatches = 0;
    for (int i = 0; i < N; i++) {
        int r = i / COLS, c = i % COLS;
        int core_mn = r / 8, core_k = c / 8, local_r = r % 8, local_c = c % 8;
        int canon_idx = core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c;
        if (h_data[i] != h_tma[canon_idx]) canonical_mismatches++;
    }
    printf("Host data → canonical: %d mismatches (checks mapping correctness)\n", canonical_mismatches);

    printf("%s\n", mismatches == 0 ? "PASSED" : "FAILED");

    cleanup();
    return mismatches == 0 ? 0 : 1;
}
