// nvfp4-megamoe-kernel/tests/unit/test_tma_konly.cu

/**
* TMA FMHA kernel — built on top of the proven test_fmha_gen pattern.
* Step 1: Add TMA for K loads only (Q stays as direct load for T=1 decode).
* Once K-TMA works, add Q-TMA for multi-row prefill.
*/
#include <cuda_runtime.h>
#include <cuda.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#ifndef HD_VAL
#define HD_VAL 64
#endif
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
#include "dsv4/kernels/attention/fmha_tma.cuh"
using namespace dsv4::kernels::attention;
// Host-side float -> bf16 conversion by truncation: keeps the upper 16 bits of
// the IEEE-754 float (sign, exponent, top 7 mantissa bits); no rounding.
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits = 0;
    std::memcpy(&bits, &f, sizeof bits);
    const uint16_t upper_half = static_cast<uint16_t>(bits >> 16);
    return upper_half;
}
// Host-side bf16 -> float widening: the 16 bf16 bits become the upper half of a
// 32-bit float pattern with a zero lower mantissa (exact, no rounding involved).
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t widened = static_cast<uint32_t>(h) << 16;
    float result;
    std::memcpy(&result, &widened, sizeof result);
    return result;
}
constexpr int HD = HD_VAL;                       // head dim (override with -DHD_VAL=...)
constexpr int SK = 128;                          // number of K rows (keys) in this test
constexpr int NKT = HD / MMA_K_BF16;             // MMA_K_BF16-wide sub-tiles along the head dim
constexpr int BLOCK_MN = 128;                    // UMMA M = N tile extent
constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16;   // bf16 elements in one BLOCK_MN x MMA_K_BF16 sub-tile
constexpr int TMEM_N = (HD <= 128) ? 128 : 256;  // TMEM columns allocated by the kernel
constexpr int CORES_MN = BLOCK_MN / 8;           // 8-row core matrices along M/N (16 for BLOCK_MN=128)
constexpr int NUM_READS = SK / 8;                // 8-column tcgen05.ld reads needed to cover SK outputs

// Compile-time guards for the single-tile assumptions baked into the kernel and host code:
// the kernel scatters SK rows into one BLOCK_MN-row sK0 tile, reads NUM_READS*8 accumulator
// columns from a TMEM_N-column allocation, and the host slices one MMA_K_BF16-wide K sub-tile.
static_assert(NKT >= 1, "HD must cover at least one MMA_K_BF16-wide K sub-tile");
static_assert(SK % 8 == 0, "SK must be a multiple of 8 (8-column TMEM reads)");
static_assert(SK <= BLOCK_MN, "only a single BLOCK_MN-row K tile is loaded");
static_assert(BLOCK_MN % 8 == 0, "BLOCK_MN must be a whole number of 8-row core matrices");
static_assert(BLOCK_MN <= TMEM_N, "the BLOCK_MN-column accumulator must fit in the TMEM allocation");
// ===== TMA K-load variant =====
// Computes the QK^T accumulator for the first MMA_K_BF16-wide head-dim sub-tile only
// (kt = 0), with no softmax scale, and writes accumulator row 0 (T = 1 decode):
//   out_s[j] = sum_{d < MMA_K_BF16} q[d] * K[j][d]   for j < s_k
// Launch: one CTA of 128 threads (matches __launch_bounds__(128)).
// Requires an sm_100-class GPU (tcgen05 / TMEM instructions).
// Dynamic SMEM: 4-byte TMEM slot padded to 16 B, three TILE_SZ bf16 tiles
// (sQ0, sK0, sTmaBuf, the last aligned to 128 B relative to sbuf) and an 8-byte mbarrier.
// tma_k: device-resident tensor map; the kernel expects one TMA box to deliver
// BLOCK_MN rows x MMA_K_BF16 bf16 columns row-major into sTmaBuf
// (NOTE(review): inferred from the host setup and the expect_tx byte count — confirm
// against create_tma_desc_2d_bf16).
__global__ void __launch_bounds__(128)
fmha_tma_konly_kernel(
    float* __restrict__ out_s,
    const bf16_t* __restrict__ q,
    CUtensorMap* __restrict__ tma_k,
    int s_k
) {
    const int tid = threadIdx.x;
    const int wid = tid / 32;
    const int lane = tid % 32;

    // ---- SMEM carve-up (offsets are relative to sbuf) ----
    extern __shared__ __align__(128) char sbuf[];
    size_t off = 0;
    uint32_t* sTmemBase = (uint32_t*)sbuf; off = 4;   // receives the TMEM base address from tmem_alloc
    off = (off + 15) & ~(size_t)15;                    // 16-byte align for Q0
    bf16_t* sQ0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
    bf16_t* sK0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
    off = (off + 127) & ~(size_t)127;                  // 128-byte align for the TMA destination
    bf16_t* sTmaBuf = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
    off = (off + 15) & ~(size_t)15;                    // 16-byte align for the mbarrier
    uint64_t* sMbar = (uint64_t*)(sbuf + off); off += 8;

    // ---- Init: TMEM allocation (warp 1) and mbarrier with one expected arrival (thread 0) ----
    if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
    if (tid == 0) {
        tma_mbarrier_init((uint32_t)__cvta_generic_to_shared(sMbar), 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    __syncthreads();
    uint32_t tb = *sTmemBase;
    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
    // The one mbarrier is reused for the TMA completion and then the MMA completion;
    // `phase` tracks the parity to wait on and flips after every completed phase.
    int phase = 0;

    // ===== QK GEMM =====
    {
        uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
        for (int kt = 0; kt < 1; kt++) { // Only 1 K sub-tile for now
            // Q sub-tile: zero the whole tile, then place row 0 (T=1) directly from GMEM
            // into the canonical core-matrix layout (ck * CORES_MN * 64 + lc).
            for (int i = tid; i < TILE_SZ; i += 128) sQ0[i] = 0;
            // Barrier required: the scatter below writes indices (e.g. CORES_MN*64 for d=8)
            // that a different thread zeroed above.
            __syncthreads();
            for (int d = tid; d < MMA_K_BF16; d += 128) {
                int full_d = kt * MMA_K_BF16 + d;
                if (full_d < HD) {
                    int ck = d / 8, lc = d % 8;
                    sQ0[ck * CORES_MN * 64 + lc] = q[full_d];
                }
            }
            __syncthreads();

            // K sub-tile via TMA, issued by a single thread. The barrier is armed with
            // the expected byte count before the copy is issued.
            if (wid == 0 && lane == 0) {
                tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_SZ * sizeof(bf16_t));
                tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_k, mbar_addr, kt * MMA_K_BF16, 0);
            }
            // ALL threads wait for TMA completion
            tma_mbarrier_wait(mbar_addr, phase);
            phase ^= 1;
            __syncthreads();

            // Convert TMA row-major [row][MMA_K_BF16] -> canonical core-matrix layout.
            for (int i = tid; i < TILE_SZ; i += 128) sK0[i] = 0;
            // Barrier required: zero-fill index i and scatter target of i belong to different threads.
            __syncthreads();
            for (int i = tid; i < SK * MMA_K_BF16; i += 128) {
                int r = i / MMA_K_BF16, c = i % MMA_K_BF16;
                int ck = c / 8, lc = c % 8, tmn = r / 8, lr = r % 8;
                sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = sTmaBuf[i];
            }
            // sQ0/sK0 were written through the generic proxy but are read by tcgen05.mma
            // through the async proxy; make the writes visible before the MMA is issued.
            asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
            __syncthreads();

            // MMA (single issuing thread); tcgen05.mma is asynchronous, so commit it to the
            // mbarrier and have every thread wait before touching the accumulator or smem.
            if (tid == 0) {
                uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), BLOCK_MN);
                uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), BLOCK_MN);
                umma_ss_f16(tb, dq, dk, idesc, kt > 0);
                asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                             :: "r"(mbar_addr) : "memory");
            }
            tma_mbarrier_wait(mbar_addr, phase);
            phase ^= 1;
            // Order subsequent tcgen05 operations (the TMEM loads below) after the wait.
            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
            __syncthreads();
        }
    }

    // ===== Read S from TMEM =====
    // Warp 0 reads 8 accumulator columns per tcgen05.ld (32x32b shape, lanes 0-31).
    if (wid == 0) {
        for (int n = 0; n < NUM_READS; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                  "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                : "r"(tb + n * 8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            // For T=1, only lane 0's data matters
            if (lane == 0) {
                for (int c = 0; c < 8; c++) {
                    int col = n * 8 + c;
                    if (col < s_k) out_s[col] = tmp[c];
                }
            }
        }
    }
    if (wid == 0) tmem_dealloc(tb, TMEM_N);
}
// Host driver: builds random bf16 Q (one row) and K ([SK][HD] row-major), loads the first
// MMA_K_BF16-wide K sub-tile through a (SK x MMA_K_BF16) TMA descriptor, runs
// fmha_tma_konly_kernel and compares against a CPU reference of exactly that sub-tile.
// Returns 0 on PASS, 1 on any CUDA error or mismatch.
int main() {
    printf("TMA K-only FMHA (HD=%d)\n", HD);
    // The K sub-tile slice below takes MMA_K_BF16 columns from each HD-wide K row.
    static_assert(MMA_K_BF16 <= HD, "HD must be at least one MMA_K_BF16 sub-tile wide");

// Abort the test (exit code 1) on any failing CUDA runtime call.
#define TKO_CHECK(call)                                                          \
    do {                                                                         \
        cudaError_t tko_err_ = (call);                                           \
        if (tko_err_ != cudaSuccess) {                                           \
            printf("CUDA ERROR: %s (%s:%d: %s)\n", cudaGetErrorString(tko_err_), \
                   __FILE__, __LINE__, #call);                                   \
            return 1;                                                            \
        }                                                                        \
    } while (0)

    bf16_t* h_q = (bf16_t*)calloc(HD, sizeof(bf16_t));
    bf16_t* h_k = (bf16_t*)calloc(SK * HD, sizeof(bf16_t));
    float* h_out = (float*)malloc(SK * sizeof(float));
    if (!h_q || !h_k || !h_out) { printf("host allocation failed\n"); return 1; }
    srand(42);
    for (int i = 0; i < HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);
    for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);

    bf16_t *d_q = nullptr, *d_k = nullptr, *d_k_sub = nullptr;
    float* d_out = nullptr;
    CUtensorMap* d_tma_k = nullptr;
    TKO_CHECK(cudaMalloc(&d_q, HD * sizeof(bf16_t)));
    TKO_CHECK(cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)));
    TKO_CHECK(cudaMalloc(&d_out, SK * sizeof(float)));
    TKO_CHECK(cudaMemset(d_out, 0, SK * sizeof(float)));
    TKO_CHECK(cudaMemcpy(d_q, h_q, HD * sizeof(bf16_t), cudaMemcpyHostToDevice));
    TKO_CHECK(cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice));

    // Try simple TMA: (128, 16) descriptor for just the first K sub-tile
    // If this works, the issue is with (128, 64) descriptors
    // K is row-major [SK][HD]; the first sub-tile is columns [0, MMA_K_BF16) of every row,
    // so it needs a strided 2-D copy (a flat copy of SK*MMA_K_BF16 elements would instead
    // take whole leading rows of K).
    TKO_CHECK(cudaMalloc(&d_k_sub, SK * MMA_K_BF16 * sizeof(bf16_t)));
    TKO_CHECK(cudaMemcpy2D(d_k_sub, MMA_K_BF16 * sizeof(bf16_t),
                           d_k, HD * sizeof(bf16_t),
                           MMA_K_BF16 * sizeof(bf16_t), SK,
                           cudaMemcpyDeviceToDevice));

    CUtensorMap tma_k;
    if (!create_tma_desc_2d_bf16(&tma_k, d_k_sub, SK, (uint64_t)MMA_K_BF16, BLOCK_MN, MMA_K_BF16)) {
        printf("TMA K desc FAILED\n"); return 1;
    }
    TKO_CHECK(cudaMalloc(&d_tma_k, sizeof(CUtensorMap)));
    TKO_CHECK(cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice));

    // Dynamic SMEM upper bound for the kernel's carve-up: 4-byte TMEM slot padded to 16,
    // three bf16 tiles (sQ0, sK0, sTmaBuf), up to 127 bytes of padding before sTmaBuf,
    // up to 15 bytes before the mbarrier, and the 8-byte mbarrier itself.
    const int smem = 16 + 3 * TILE_SZ * (int)sizeof(bf16_t) + 127 + 15 + (int)sizeof(uint64_t);
    TKO_CHECK(cudaFuncSetAttribute(fmha_tma_konly_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem));
    fmha_tma_konly_kernel<<<1, 128, smem>>>(d_out, d_q, d_tma_k, SK);
    TKO_CHECK(cudaGetLastError());        // launch-configuration errors
    TKO_CHECK(cudaDeviceSynchronize());   // errors raised while the kernel ran

    TKO_CHECK(cudaMemcpy(h_out, d_out, SK * sizeof(float), cudaMemcpyDeviceToHost));

    // Reference for exactly what the kernel computes: the first MMA_K_BF16-wide K
    // sub-tile only (kt = 0) and no softmax scale:
    //   S[j] = sum_{d < MMA_K_BF16} q[d] * K[j][d]
    int fail = 0; float max_rel = 0;
    for (int j = 0; j < SK; j++) {
        float dot = 0;
        for (int d = 0; d < MMA_K_BF16; d++)
            dot += bf16_to_f32_host(h_q[d]) * bf16_to_f32_host(h_k[j * HD + d]);
        float ref = dot;
        float got = h_out[j];
        float rel = fabsf(ref) > 1e-4f ? fabsf(got - ref) / fabsf(ref) : fabsf(got - ref);
        if (rel > max_rel) max_rel = rel;
        if (rel > 0.01f && fail < 5) printf(" j=%d: ref=%.6f got=%.6f rel=%.4f\n", j, ref, got, rel);
        if (rel > 0.01f) fail++;
    }
    printf("Max rel err: %.8f, failures: %d\n", max_rel, fail);
    printf("%s\n", fail == 0 ? "PASSED" : "FAILED");

    free(h_q); free(h_k); free(h_out);
    cudaFree(d_q); cudaFree(d_k); cudaFree(d_k_sub); cudaFree(d_out); cudaFree(d_tma_k);
#undef TKO_CHECK
    return fail == 0 ? 0 : 1;
}