134 lines
4.7 KiB
Plaintext
134 lines
4.7 KiB
Plaintext
/**
 * Absolute minimum TMA + QK test:
 *   1. Load K via TMA into sTmaBuf.
 *   2. Convert to canonical sK0.
 *   3. Write sK0 to GMEM for verification.
 * No TMEM, no MMA.
 */
|
|
|
|
#include <cuda_runtime.h>
|
|
#include <cuda.h>
|
|
#include <cstdio>
|
|
#include <cstdlib>
|
|
#include <cstring>
|
|
|
|
#ifndef HD_VAL
|
|
#define HD_VAL 64
|
|
#endif
|
|
|
|
#include "dsv4/kernels/attention/fmha_common.cuh"
|
|
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
|
#include "dsv4/kernels/attention/fmha_tma.cuh"
|
|
|
|
using namespace dsv4::kernels::attention;
|
|
|
|
// Host-side float -> bf16 conversion by truncation: reinterprets the float's
// bit pattern and keeps only its upper 16 bits (no rounding is applied).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t raw_bits;
    memcpy(&raw_bits, &f, 4);
    const uint32_t upper_half = raw_bits >> 16;
    return (uint16_t)upper_half;
}
|
|
// Host-side bf16 -> float conversion: places the 16 stored bits in the upper
// half of a 32-bit word (lower half zero) and reinterprets that word as float.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t widened = ((uint32_t)h) << 16;
    float result;
    memcpy(&result, &widened, 4);
    return result;
}
|
|
|
|
// Compile-time geometry shared by the kernel and the host-side verification.
constexpr int HD{HD_VAL};        // only printed by main(); presumably the head dimension — TODO confirm
constexpr int SK{128};           // number of K rows iterated by the layout conversion
constexpr int MMA_K{MMA_K_BF16}; // columns per row of the tile (value comes from a project header)
constexpr int CORES_MN{16};      // number of 8-row groups per 8-column slab in the canonical index
constexpr int TILE_SZ{128 * MMA_K}; // elements in one (128, MMA_K) tile
|
|
|
|
/**
 * Loads one K tile through TMA, re-lays it out into the "canonical" order
 * used by this file, and dumps both layouts to global memory.
 *
 * Launch contract (as used by main() in this file):
 *   - one block of 128 threads (see __launch_bounds__);
 *   - dynamic shared memory of at least 2 * TILE_SZ * sizeof(bf16_t) bytes
 *     plus 8 bytes for the mbarrier word.
 *
 * @param out_canonical  Receives TILE_SZ elements in canonical order:
 *                       element (r, c) is stored at
 *                       (c/8)*CORES_MN*64 + (r/8)*64 + (r%8)*8 + (c%8).
 * @param out_rowmajor   Receives TILE_SZ elements copied verbatim from the
 *                       TMA staging buffer.
 * @param tma_k          Device pointer to the tensor-map descriptor passed on
 *                       to tma_load_2d.
 */
__global__ void __launch_bounds__(128)
test_tma_kload_kernel(
    bf16_t* __restrict__ out_canonical,  // (128, 16) canonical
    bf16_t* __restrict__ out_rowmajor,   // (128, 16) row-major
    CUtensorMap* __restrict__ tma_k
) {
    const int tid = threadIdx.x;
    const int wid = tid / 32;
    const int lane = tid % 32;
    // Loop stride follows the actual block size (128 for the launch in this file).
    const int nthreads = blockDim.x;

    // Dynamic SMEM layout: sTmaBuf (TILE_SZ elems) | sCanonical (TILE_SZ elems) | sMbar.
    extern __shared__ __align__(128) char sbuf[];
    bf16_t* sTmaBuf = (bf16_t*)(sbuf);
    bf16_t* sCanonical = (bf16_t*)(sbuf + TILE_SZ * sizeof(bf16_t));
    uint64_t* sMbar = (uint64_t*)(sbuf + 2 * TILE_SZ * sizeof(bf16_t));

    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);

    // One thread initialises the barrier; the block-wide sync below makes
    // sure nobody touches it before initialisation has happened.
    if (tid == 0) {
        tma_mbarrier_init(mbar_addr, 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    __syncthreads();

    // TMA load: (128, 16) tile at coord {0, 0}, issued by a single thread.
    if (wid == 0 && lane == 0) {
        uint32_t smem_dst = (uint32_t)__cvta_generic_to_shared(sTmaBuf);
        tma_load_2d(smem_dst, (uint64_t)tma_k, mbar_addr, 0, 0);
        tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_SZ * sizeof(bf16_t));
    }
    // Every thread waits on the barrier (second argument 0 is presumably the
    // phase parity — confirm against fmha_tma.cuh).
    tma_mbarrier_wait(mbar_addr, 0);
    __syncthreads();

    // Clear the canonical buffer.
    for (int i = tid; i < TILE_SZ; i += nthreads) sCanonical[i] = 0;

    // FIX: the slot a thread scatters into below is, in general, zeroed by a
    // *different* thread above. Without this barrier a slow thread could
    // overwrite already-converted data with zeros.
    __syncthreads();

    // Convert row-major -> canonical (8x8 blocks, 64 elements each).
    for (int i = tid; i < SK * MMA_K; i += nthreads) {
        const int r = i / MMA_K, c = i % MMA_K;
        const int ck = c / 8, lc = c % 8, tmn = r / 8, lr = r % 8;
        sCanonical[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = sTmaBuf[i];
    }
    __syncthreads();

    // Write canonical to GMEM
    for (int i = tid; i < TILE_SZ; i += nthreads) out_canonical[i] = sCanonical[i];
    // Write row-major to GMEM
    for (int i = tid; i < TILE_SZ; i += nthreads) out_rowmajor[i] = sTmaBuf[i];
}
|
|
|
|
/**
 * Host driver for test_tma_kload_kernel.
 *
 * Fills an (SK, MMA_K) bf16 tile with pseudo-random values (fixed seed),
 * builds a tensor-map descriptor for it, runs the kernel once, and checks
 * both kernel outputs against the host data:
 *   - the row-major dump must equal the input element for element;
 *   - the canonical dump must hold element (r, c) at
 *     (c/8)*CORES_MN*64 + (r/8)*64 + (r%8)*8 + (c%8).
 *
 * Returns 0 when both checks pass, 1 on any mismatch or on any allocation,
 * descriptor, launch or CUDA runtime failure. All host and device buffers
 * are released on every exit path.
 */
int main() {
    printf("TMA K-load only (HD=%d)\n", HD);

    const size_t k_bytes    = (size_t)SK * MMA_K * sizeof(bf16_t);
    const size_t tile_bytes = (size_t)TILE_SZ * sizeof(bf16_t);

    bf16_t *h_k = nullptr, *h_c = nullptr, *h_r = nullptr;
    bf16_t *d_k = nullptr, *d_out_c = nullptr, *d_out_r = nullptr;
    CUtensorMap* d_tma_k = nullptr;

    // Releases everything acquired so far; free/cudaFree accept null pointers.
    auto release_all = [&]() {
        free(h_k);
        free(h_c);
        free(h_r);
        cudaFree(d_k);
        cudaFree(d_out_c);
        cudaFree(d_out_r);
        cudaFree(d_tma_k);
    };
    // Reports a CUDA runtime error; returns true when the call failed.
    auto cuda_failed = [](cudaError_t e) -> bool {
        if (e == cudaSuccess) return false;
        printf("CUDA ERROR: %s\n", cudaGetErrorString(e));
        return true;
    };

    h_k = (bf16_t*)calloc(SK * MMA_K, sizeof(bf16_t));
    h_c = (bf16_t*)malloc(tile_bytes);
    h_r = (bf16_t*)malloc(tile_bytes);
    if (!h_k || !h_c || !h_r) {
        printf("Host allocation FAILED\n");
        release_all();
        return 1;
    }

    // Deterministic input in [-0.5, 0.49].
    srand(42);
    for (int i = 0; i < SK * MMA_K; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);

    if (cuda_failed(cudaMalloc(&d_k, k_bytes)) ||
        cuda_failed(cudaMalloc(&d_out_c, tile_bytes)) ||
        cuda_failed(cudaMalloc(&d_out_r, tile_bytes)) ||
        cuda_failed(cudaMemcpy(d_k, h_k, k_bytes, cudaMemcpyHostToDevice))) {
        release_all();
        return 1;
    }

    // Build the descriptor on the host, then stage a copy in device memory
    // because the kernel takes it by pointer.
    CUtensorMap tma_k;
    if (!create_tma_desc_2d_bf16(&tma_k, d_k, SK, (uint64_t)MMA_K, 128, 16)) {
        printf("TMA desc FAILED\n");
        release_all();
        return 1;
    }
    if (cuda_failed(cudaMalloc(&d_tma_k, sizeof(CUtensorMap))) ||
        cuda_failed(cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice))) {
        release_all();
        return 1;
    }

    // Two tiles plus 16 bytes of room for the mbarrier word.
    const size_t smem = TILE_SZ * 2 * sizeof(bf16_t) + 16;
    test_tma_kload_kernel<<<1, 128, smem>>>(d_out_c, d_out_r, d_tma_k);
    // cudaGetLastError catches launch-configuration errors;
    // cudaDeviceSynchronize surfaces faults raised while the kernel ran.
    if (cuda_failed(cudaGetLastError()) ||
        cuda_failed(cudaDeviceSynchronize())) {
        release_all();
        return 1;
    }

    if (cuda_failed(cudaMemcpy(h_c, d_out_c, tile_bytes, cudaMemcpyDeviceToHost)) ||
        cuda_failed(cudaMemcpy(h_r, d_out_r, tile_bytes, cudaMemcpyDeviceToHost))) {
        release_all();
        return 1;
    }

    // Verify row-major
    int rm_mismatches = 0;
    for (int i = 0; i < SK * MMA_K; i++) {
        if (h_r[i] != h_k[i]) rm_mismatches++;
    }
    printf("Row-major: %d mismatches out of %d\n", rm_mismatches, SK * MMA_K);

    // Verify canonical
    int cn_mismatches = 0;
    for (int r = 0; r < SK; r++) {
        for (int c = 0; c < MMA_K; c++) {
            int ck = c/8, lc = c%8, tmn = r/8, lr = r%8;
            int canon_idx = ck*CORES_MN*64 + tmn*64 + lr*8 + lc;
            if (h_k[r*MMA_K + c] != h_c[canon_idx]) cn_mismatches++;
        }
    }
    printf("Canonical: %d mismatches out of %d\n", cn_mismatches, SK * MMA_K);

    const bool passed = (rm_mismatches + cn_mismatches == 0);
    printf("%s\n", passed ? "PASSED" : "FAILED");

    release_all();
    return passed ? 0 : 1;
}
|