194 lines
6.8 KiB
Plaintext
194 lines
6.8 KiB
Plaintext
/**
 * TMA sub-tile load test — replicates the exact pattern from test_qk_softmax.
 * Builds a TMA descriptor for the (SK, HD) K matrix and loads one sub-tile per
 * step at (coord_x = kt*16, coord_y = 0).
 * Compare against the working test_fmha_6warp, which uses a direct GMEM load.
 */
|
|
|
|
#include <cuda_runtime.h>
|
|
#include <cuda.h>
|
|
#include <cstdio>
|
|
#include <cmath>
|
|
#include <cstdlib>
|
|
#include <cstring>
|
|
|
|
// Number of BF16 columns in the K matrix (becomes the HD constant).
// Override on the compiler command line with -DHD_VAL=<n>.
#if !defined(HD_VAL)
#  define HD_VAL 64
#endif

// Threads per block: used for both __launch_bounds__ and the kernel launch.
// Override on the compiler command line with -DNUM_THREADS=<n>.
#if !defined(NUM_THREADS)
#  define NUM_THREADS 128
#endif
|
|
|
|
/** Raw 16-bit storage for a bfloat16 value (bit pattern only). */
typedef unsigned short bf16_t;

/**
 * Converts a float to bfloat16 on the host by keeping the upper 16 bits of
 * its IEEE-754 bit pattern. The low 16 bits are dropped (truncation, no
 * rounding).
 */
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits = 0;
    memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}

/**
 * Expands a bfloat16 bit pattern to float on the host by placing it in the
 * upper 16 bits of a 32-bit word and zero-filling the lower half.
 */
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float value;
    memcpy(&value, &bits, sizeof(value));
    return value;
}
|
|
|
|
// Shape of the K matrix under test: SK rows by HD columns of BF16.
constexpr int HD = HD_VAL;
constexpr int SK = 128;
// Number of BF16 columns fetched by one TMA sub-tile load.
constexpr int MMA_K_BF16 = 16;
// Elements in one sub-tile: 128 rows x MMA_K_BF16 columns.
constexpr int TILE_SZ = 128 * MMA_K_BF16;  // 2048 BF16

// The host computes the sub-tile count as HD / MMA_K_BF16; a remainder would
// silently leave the trailing columns untested.
static_assert(HD > 0 && HD % MMA_K_BF16 == 0,
              "HD_VAL must be a positive multiple of MMA_K_BF16 (16)");
// TMA requires the innermost box extent to be a multiple of 16 bytes
// (2 bytes per BF16 element).
static_assert((MMA_K_BF16 * 2) % 16 == 0,
              "sub-tile row must be a multiple of 16 bytes for TMA");
|
|
|
|
/**
 * Initialises the mbarrier object located at `smem_mbar`.
 *
 * @param smem_mbar  32-bit shared-memory address of the 64-bit mbarrier.
 * @param count      Arrival count the barrier is initialised with.
 *
 * The "memory" clobber keeps the compiler from moving shared-memory accesses
 * across the initialisation, matching the other mbarrier wrappers in this file.
 */
__device__ __forceinline__ void tma_mbarrier_init(uint32_t smem_mbar, uint32_t count) {
    asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;"
                 :: "r"(smem_mbar), "r"(count)
                 : "memory");
}
|
|
/**
 * Performs one arrival on the mbarrier at `smem_mbar` and, in the same
 * instruction, registers `tx_bytes` as the number of transaction bytes the
 * barrier should expect. The returned barrier state is discarded (`_`).
 *
 * @param smem_mbar  32-bit shared-memory address of the mbarrier.
 * @param tx_bytes   Byte count to expect from asynchronous copies.
 */
__device__ __forceinline__ void tma_mbarrier_arrive_expect_tx(uint32_t smem_mbar, uint32_t tx_bytes) {
    asm volatile(
        "mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 "
        "_, [%0], %1;"
        :
        : "r"(smem_mbar), "r"(tx_bytes)
        : "memory");
}
|
|
/**
 * Issues an asynchronous 2-D bulk tensor copy (TMA) from global memory into
 * shared memory. Completion is signalled as transaction bytes on an mbarrier.
 *
 * @param smem_dst   32-bit shared-memory address of the destination buffer.
 * @param tma_desc   64-bit address of the tensor map describing the source.
 * @param smem_mbar  32-bit shared-memory address of the mbarrier to signal.
 * @param coord_x    First tensor coordinate of the tile origin.
 * @param coord_y    Second tensor coordinate of the tile origin.
 */
__device__ __forceinline__ void tma_load_2d(
    uint32_t smem_dst, uint64_t tma_desc, uint32_t smem_mbar,
    int coord_x, int coord_y
) {
    // Operand order in the template is dst, {desc, x, y}, mbarrier; the
    // numbering %0..%4 follows the input list below.
    asm volatile(
        "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes "
        "[%0], [%1, {%3, %4}], [%2];"
        :
        : "r"(smem_dst),
          "l"(tma_desc),
          "r"(smem_mbar),
          "r"(coord_x),
          "r"(coord_y)
        : "memory");
}
|
|
/**
 * Spins until the mbarrier at `smem_mbar` has completed the phase with the
 * given parity.
 *
 * @param smem_mbar  32-bit shared-memory address of the mbarrier.
 * @param phase      Parity (0 or 1) of the phase to wait for.
 *
 * The branches are plain `bra` rather than `bra.uni`: the predicate comes
 * from a per-thread `mbarrier.try_wait`, so warp-uniformity is not guaranteed,
 * and `.uni` is only valid when every active thread agrees on the predicate.
 */
__device__ __forceinline__ void tma_mbarrier_wait(uint32_t smem_mbar, int phase) {
    // Suspend-time hint passed to try_wait (0x989680 == 10,000,000; the PTX
    // ISA defines this operand in nanoseconds).
    constexpr uint32_t kSuspendTimeHint = 0x989680;
    asm volatile(
        "{\n\t"
        ".reg .pred P1;\n\t"
        "LAB_WAIT:\n\t"
        "mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\n\t"
        "@P1 bra DONE;\n\t"
        "bra LAB_WAIT;\n\t"
        "DONE:\n\t"
        "}"
        :: "r"(smem_mbar), "r"(phase), "r"(kSuspendTimeHint)
        : "memory");
}
|
|
|
|
/**
 * Test: TMA-load sub-tiles from a (SK, HD) K matrix.
 * This is the EXACT pattern from test_qk_softmax.cu.
 *
 * For each kt in [0, n_kt) one thread issues a 2-D TMA load at
 * (coord_x = kt * MMA_K_BF16, coord_y = 0) into shared memory. The whole
 * block then copies the TILE_SZ loaded elements to
 * verify_buf[kt * TILE_SZ ...].
 *
 * Launch: intended for a single block (every block would write the same
 * verify_buf range). Requires sm_90+ for the TMA / mbarrier PTX.
 * Dynamic shared memory: TILE_SZ * sizeof(bf16_t) bytes for the tile, rounded
 * up to 16 bytes, plus 8 bytes for the mbarrier.
 *
 * @param verify_buf  Output, n_kt * TILE_SZ elements: what was loaded via TMA.
 * @param tma_k       Tensor map for K, resident in global memory.
 * @param s_k         Unused; kept so the signature stays unchanged.
 * @param n_kt        Number of sub-tiles to load.
 */
__global__ void __launch_bounds__(NUM_THREADS)
test_tma_subtile_kernel(
    bf16_t* __restrict__ verify_buf,  // what we loaded via TMA
    CUtensorMap* __restrict__ tma_k,
    int s_k, int n_kt
) {
    (void)s_k;
    const int tid = threadIdx.x;

    // Carve the dynamic shared memory into the tile buffer and the mbarrier.
    extern __shared__ __align__(128) char sbuf[];
    size_t off = 0;
    bf16_t* sTmaBuf = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
    off = (off + 15) & ~(size_t)15;
    uint64_t* sMbar = (uint64_t*)(sbuf + off); off += 8;

    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
    const uint32_t tile_addr = (uint32_t)__cvta_generic_to_shared(sTmaBuf);

    if (tid == 0) {
        tma_mbarrier_init(mbar_addr, 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
        // The tensor map lives in global memory (not a __grid_constant__
        // parameter or __constant__ memory), so it must be fenced into the
        // tensormap proxy before the first TMA instruction that reads it.
        asm volatile("fence.proxy.tensormap::generic.acquire.sys [%0], 128;"
                     :: "l"(tma_k)
                     : "memory");
    }
    __syncthreads();

    int phase = 0;

    for (int kt = 0; kt < n_kt; kt++) {
        // Load K sub-tile via TMA — same as test_qk_softmax
        if (tid == 0) {
            tma_load_2d(tile_addr, (uint64_t)tma_k, mbar_addr, kt * MMA_K_BF16, 0);
            tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_SZ * sizeof(bf16_t));
        }
        // Every thread waits for the current phase, then flips its parity.
        tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
        __syncthreads();

        // Copy sub-tile to verify buffer. Stride by the real block size so
        // the copy stays complete for any launch width <= NUM_THREADS.
        const int base = kt * TILE_SZ;
        for (int i = tid; i < TILE_SZ; i += (int)blockDim.x) {
            verify_buf[base + i] = sTmaBuf[i];
        }
        // All reads of the tile must finish before the next TMA overwrites it.
        __syncthreads();
    }
}
|
|
|
|
/**
 * Encodes a tiled 2-D BF16 tensor map for a (rows, cols) matrix.
 *
 * @param out        Host-side tensor map to fill in.
 * @param gmem_ptr   Base address of the matrix, passed through as the global address.
 * @param rows       Matrix row count.
 * @param cols       Matrix column count (the row pitch is cols * 2 bytes).
 * @param tile_rows  Box (tile) height in elements.
 * @param tile_cols  Box (tile) width in elements.
 * @return true on success; false (after logging to stderr) when
 *         cuTensorMapEncodeTiled reports an error.
 */
inline bool create_tma_desc_2d_bf16(
    CUtensorMap* out, const void* gmem_ptr,
    uint64_t rows, uint64_t cols,
    uint32_t tile_rows, uint32_t tile_cols
) {
    // Columns are listed first, rows second; only the row pitch needs an
    // explicit byte stride (2 bytes per BF16 element).
    uint64_t tensor_extent[2]   = {cols, rows};
    uint64_t row_pitch_bytes[1] = {cols * 2};
    uint32_t box_extent[2]      = {tile_cols, tile_rows};
    uint32_t element_step[2]    = {1, 1};

    const CUresult status = cuTensorMapEncodeTiled(
        out,
        CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
        2,
        const_cast<void*>(gmem_ptr),
        tensor_extent,
        row_pitch_bytes,
        box_extent,
        element_step,
        CU_TENSOR_MAP_INTERLEAVE_NONE,
        CU_TENSOR_MAP_SWIZZLE_NONE,
        CU_TENSOR_MAP_L2_PROMOTION_NONE,
        CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    if (status != CUDA_SUCCESS) {
        fprintf(stderr, "cuTensorMapEncodeTiled failed: %d\n", (int)status);
        return false;
    }

    // NOTE(review): for driver versions <= 13010 and tensors smaller than
    // 128 KiB, bit 21 of the descriptor's second 64-bit word is cleared.
    // This looks like a driver workaround; the descriptor layout is opaque,
    // so confirm the intent with the author. The query's return code is
    // ignored, so a failed query leaves the version at 0 and the patch is
    // applied.
    int driver_version = 0;
    cudaDriverGetVersion(&driver_version);
    const size_t tensor_bytes = rows * cols * 2;
    const bool patch_descriptor = (driver_version <= 13010) && (tensor_bytes < 131072);
    if (patch_descriptor) {
        uint64_t* words = reinterpret_cast<uint64_t*>(out);
        words[1] &= ~(1ULL << 21);
    }
    return true;
}
|
|
|
|
/** Owns every host/device allocation of the test so all exit paths release them. */
struct TmaSubtileTestBuffers {
    bf16_t* h_k = nullptr;        // host copy of K
    bf16_t* h_verify = nullptr;   // host copy of the kernel output
    bf16_t* d_k = nullptr;        // device K matrix
    bf16_t* d_verify = nullptr;   // device output buffer
    CUtensorMap* d_tma_k = nullptr;  // device copy of the tensor map

    ~TmaSubtileTestBuffers() {
        // cudaFree / free accept null pointers, so partial setups are fine.
        cudaFree(d_k); cudaFree(d_verify); cudaFree(d_tma_k);
        free(h_k); free(h_verify);
    }
};

/** Reports a failed CUDA runtime call; returns true when `status` is cudaSuccess. */
static bool tma_test_cuda_ok(cudaError_t status, const char* what) {
    if (status == cudaSuccess) return true;
    fprintf(stderr, "%s failed\n", what);
    printf("CUDA ERROR: %s\n", cudaGetErrorString(status));
    return false;
}

/**
 * Host driver: fills K with pseudo-random BF16 values, runs the TMA sub-tile
 * kernel in a single block, and compares every loaded sub-tile with the
 * matching column slice of K.
 *
 * @return 0 when all elements match, 1 on any mismatch or setup/runtime error.
 */
int main() {
    printf("=== TMA sub-tile load test (HD=%d, NUM_THREADS=%d) ===\n", HD, NUM_THREADS);
    const int n_kt = HD / MMA_K_BF16;  // number of K sub-tiles
    if (n_kt < 1) {
        // Without this the verification loop would be empty and report PASSED.
        fprintf(stderr, "HD=%d is smaller than one %d-column sub-tile\n", HD, MMA_K_BF16);
        return 1;
    }

    const size_t k_elems = (size_t)SK * HD;
    const size_t k_bytes = k_elems * sizeof(bf16_t);
    const size_t verify_bytes = (size_t)n_kt * TILE_SZ * sizeof(bf16_t);

    TmaSubtileTestBuffers buf;
    buf.h_k = (bf16_t*)calloc(k_elems, sizeof(bf16_t));
    buf.h_verify = (bf16_t*)malloc(verify_bytes);
    if (buf.h_k == nullptr || buf.h_verify == nullptr) {
        fprintf(stderr, "host allocation failed\n");
        return 1;
    }

    // Fixed seed keeps the input reproducible; values lie in [-0.5, 0.49].
    srand(42);
    for (size_t i = 0; i < k_elems; i++) {
        buf.h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);
    }

    if (!tma_test_cuda_ok(cudaMalloc(&buf.d_k, k_bytes), "cudaMalloc(d_k)")) return 1;
    if (!tma_test_cuda_ok(cudaMalloc(&buf.d_verify, verify_bytes), "cudaMalloc(d_verify)")) return 1;
    if (!tma_test_cuda_ok(cudaMemcpy(buf.d_k, buf.h_k, k_bytes, cudaMemcpyHostToDevice),
                          "cudaMemcpy(d_k)")) return 1;
    if (!tma_test_cuda_ok(cudaMemset(buf.d_verify, 0, verify_bytes), "cudaMemset(d_verify)")) return 1;

    // TMA descriptor for full K matrix: (SK, HD) with tile (128, 16)
    CUtensorMap tma_k;
    if (!create_tma_desc_2d_bf16(&tma_k, buf.d_k, SK, HD, 128, 16)) {
        printf("TMA desc FAILED\n");
        return 1;
    }
    if (!tma_test_cuda_ok(cudaMalloc(&buf.d_tma_k, sizeof(CUtensorMap)), "cudaMalloc(d_tma_k)")) return 1;
    if (!tma_test_cuda_ok(cudaMemcpy(buf.d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice),
                          "cudaMemcpy(d_tma_k)")) return 1;

    // Tile bytes plus slack for the 16-byte-aligned mbarrier and alignment.
    const size_t smem = TILE_SZ * 2 + 16 + 128;
    test_tma_subtile_kernel<<<1, NUM_THREADS, smem>>>(buf.d_verify, buf.d_tma_k, SK, n_kt);
    // Launch-configuration errors are reported only through cudaGetLastError.
    if (!tma_test_cuda_ok(cudaGetLastError(), "kernel launch")) return 1;
    if (!tma_test_cuda_ok(cudaDeviceSynchronize(), "kernel execution")) return 1;

    // Verify: each sub-tile should contain K[0:128, kt*16:kt*16+16]
    if (!tma_test_cuda_ok(cudaMemcpy(buf.h_verify, buf.d_verify, verify_bytes, cudaMemcpyDeviceToHost),
                          "cudaMemcpy(h_verify)")) return 1;

    int total_bad = 0;
    for (int kt = 0; kt < n_kt; kt++) {
        int bad = 0;
        for (int r = 0; r < 128; r++) {
            for (int c = 0; c < MMA_K_BF16; c++) {
                const bf16_t expected = buf.h_k[r * HD + kt * MMA_K_BF16 + c];
                const bf16_t got = buf.h_verify[kt * TILE_SZ + r * MMA_K_BF16 + c];
                if (got != expected) bad++;
            }
        }
        if (bad > 0) printf(" kt=%d: %d mismatches\n", kt, bad);
        total_bad += bad;
    }
    printf("Total mismatches: %d\n", total_bad);
    printf("%s\n", total_bad == 0 ? "PASSED" : "FAILED");

    return total_bad == 0 ? 0 : 1;
}
|