Files
nvfp4-megamoe-kernel/tests/unit/test_qk_softmax.cu

230 lines
9.1 KiB
Plaintext

/**
* Test QK + Softmax: load Q and K, compute QK GEMM, softmax, write P to GMEM.
* Verify P values against reference. No PV.
*/
#include <cuda_runtime.h>
#include <cuda.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#ifndef HD_VAL
#define HD_VAL 64
#endif
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
#include "dsv4/kernels/attention/fmha_tma.cuh"
using namespace dsv4::kernels::attention;
// Host-side float -> bf16 conversion: keeps the upper 16 bits of the IEEE-754
// bit pattern (the low 16 mantissa bits are dropped, no rounding is applied).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    const uint32_t upper_half = bits >> 16;
    return static_cast<uint16_t>(upper_half);
}
// Host-side bf16 -> float conversion: places the 16 stored bits in the upper
// half of a 32-bit word (low half zero) and reinterprets that word as a float.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float result;
    memcpy(&result, &bits, sizeof(result));
    return result;
}
// Compile-time geometry shared by the kernel and the host driver.
constexpr int HD = HD_VAL;  // head dimension; HD_VAL defaults to 64 unless defined before this point
constexpr int SK = 128;  // number of K rows; passed to the kernel as s_k and used as the softmax width
constexpr int NKT = HD / MMA_K_BF16;  // number of MMA_K_BF16-wide slices of the head dimension the GEMM loop walks
constexpr int BLOCK_MN = 128;  // M and N extent handed to make_idesc and the UMMA descriptor helpers
constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16;  // element count of one BLOCK_MN x MMA_K_BF16 shared-memory tile
constexpr int TMEM_N = (HD <= 128) ? 128 : 256;  // column count passed to tmem_alloc / tmem_dealloc
constexpr int CORES_MN = 16;  // number of 8-row groups per tile (equals BLOCK_MN / 8); used in the tile index math
constexpr int NUM_READS = SK / 8;  // tcgen05.ld iterations per row; each iteration returns 8 columns
/**
 * Test kernel: accumulate S = Q * K^T with tcgen05 MMA, apply a row-wise
 * softmax to S * scale, and store the normalized probabilities to out_p.
 *
 * Launch: one block of exactly 128 threads (the staging loops stride by 128).
 * Dynamic shared memory must cover the carve-up below: a TMEM base slot,
 * three TILE_SZ bf16 tiles, one mbarrier and two 128-float scratch arrays.
 *
 * Preconditions (not checked here): T <= 128 and s_k <= SK. Larger values
 * would overrun the shared-memory tiles and the per-thread my_p_vals array.
 *
 * @param out_p  output, written as out_p[row * s_k + col] for row < T
 * @param q      Q matrix, read as q[row * HD + d]
 * @param tma_k  tensor map handed to tma_load_2d to fetch slices of K
 * @param T      number of valid Q rows
 * @param s_k    number of valid K rows (softmax columns)
 * @param scale  multiplier applied to every score before the softmax
 */
__global__ void __launch_bounds__(128)
test_qk_softmax_kernel(
    float* __restrict__ out_p,  // (T, s_k) softmax P values
    const bf16_t* __restrict__ q,
    CUtensorMap* __restrict__ tma_k,
    int T, int s_k, float scale
) {
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    // Carve the dynamic shared-memory buffer into its sub-regions.
    extern __shared__ __align__(128) char sbuf[];
    size_t off = 0;
    uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off += 4;
    off = (off + 127) & ~(size_t)127;  // tiles start on a 128-byte boundary
    bf16_t* sQ0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
    bf16_t* sK0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
    bf16_t* sTmaBuf = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
    off = (off + 15) & ~(size_t)15;
    uint64_t* sMbar = (uint64_t*)(sbuf + off); off += 8;
    float* sRowMax = (float*)(sbuf + off); off += 128 * sizeof(float);
    float* sRowSum = (float*)(sbuf + off); off += 128 * sizeof(float);

    // Warp 1 allocates TMEM_N tensor-memory columns; the base address is
    // published through sTmemBase. Thread 0 initialises the TMA mbarrier.
    if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
    if (tid == 0) {
        tma_mbarrier_init((uint32_t)__cvta_generic_to_shared(sMbar), 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    __syncthreads();

    uint32_t tb = *sTmemBase;
    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
    int phase = 0;

    // Each active thread owns one row of S: row = wid * 32 + lane.
    const bool my_warp_active = (T <= 32) ? (wid == 0) : (wid < 4);
    const int my_row = my_warp_active ? (wid * 32 + lane) : 0;
    const bool my_row_active = my_warp_active && (my_row < T);

    // QK GEMM: one MMA per MMA_K_BF16-wide slice of the head dimension.
    {
        uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
        for (int kt = 0; kt < NKT; kt++) {
            // Stage this slice of Q: zero the tile, then scatter the valid
            // rows into 8x8 blocks of 64 elements each.
            for (int i = tid; i < TILE_SZ; i += 128) sQ0[i] = 0;
            for (int d = tid; d < T * MMA_K_BF16; d += 128) {
                int r = d / MMA_K_BF16, c = d % MMA_K_BF16;
                int full_d = kt * MMA_K_BF16 + c;
                if (full_d < HD && r < T) {
                    int ck = c / 8, lc = c % 8, cm = r / 8, lr = r % 8;
                    sQ0[ck * CORES_MN * 64 + cm * 64 + lr * 8 + lc] = q[r * HD + full_d];
                }
            }
            __syncthreads();

            // One thread issues the TMA load of the matching K slice; every
            // thread then waits on the mbarrier for the bytes to land.
            if (wid == 0 && lane == 0) {
                tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_k, mbar_addr, kt * MMA_K_BF16, 0);
                tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_SZ * sizeof(bf16_t));
            }
            tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
            __syncthreads();

            // Re-tile K from the row-major TMA landing buffer into the same
            // 8x8 block layout used for Q.
            for (int i = tid; i < TILE_SZ; i += 128) sK0[i] = 0;
            for (int i = tid; i < s_k * MMA_K_BF16; i += 128) {
                int r = i / MMA_K_BF16, c = i % MMA_K_BF16;
                int ck = c / 8, lc = c % 8, tmn = r / 8, lr = r % 8;
                sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = sTmaBuf[i];
            }
            __syncthreads();

            // Single-thread MMA issue; the last argument is (kt > 0), so the
            // first slice is treated differently from the following ones
            // (presumably overwrite vs. accumulate -- confirm in umma_ss_f16).
            if (tid == 0) {
                uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), BLOCK_MN);
                uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), BLOCK_MN);
                umma_ss_f16(tb, dq, dk, idesc, kt > 0);
                asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
            }
            __syncthreads();
        }
    }
    asm volatile("fence.sc.gpu;" ::: "memory");
    __syncthreads();

    // Softmax pass 1: per-row maximum of the scaled scores.
    // NOTE(review): every warp reads from tb + n * 8 with no per-warp lane
    // offset. Only warp 0 is exercised when T <= 32; confirm the addressing
    // is correct for warps 1..3 before relying on T > 32.
    float my_row_max = -INFINITY;
    if (my_warp_active) {
        for (int n = 0; n < NUM_READS; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                  "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                : "r"(tb + n * 8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            if (my_row_active) {
                for (int c = 0; c < 8; c++) {
                    int col = n * 8 + c;
                    if (col < s_k) my_row_max = fmaxf(my_row_max, tmp[c] * scale);
                }
            }
        }
    }
    if (my_row_active) sRowMax[my_row] = my_row_max;
    __syncthreads();

    // Softmax pass 2: exponentials and their row sum. Each thread keeps the
    // unnormalized probabilities of its whole row in my_p_vals.
    float my_p_vals[SK];
    float my_row_sum = 0.0f;
    if (my_warp_active) {
        float rm = my_row_active ? sRowMax[my_row] : 0.0f;
        for (int n = 0; n < NUM_READS; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                  "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                : "r"(tb + n * 8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            if (my_row_active) {
                for (int c = 0; c < 8; c++) {
                    int col = n * 8 + c;
                    if (col < s_k) {
                        float p = expf(tmp[c] * scale - rm);
                        my_p_vals[col] = p;
                        my_row_sum += p;
                    }
                }
            }
        }
    }
    if (my_row_active) sRowSum[my_row] = my_row_sum;
    __syncthreads();

    // Write P to GMEM. my_p_vals is private to the thread that owns the row,
    // so that thread must store every column itself; a lane-strided loop
    // would leave most of the row unwritten.
    if (my_row_active) {
        const float inv_rs = 1.0f / sRowSum[my_row];
        for (int j = 0; j < s_k; j++) {
            out_p[my_row * s_k + j] = my_p_vals[j] * inv_rs;
        }
    }

    // All TMEM reads completed before the barrier above.
    // NOTE(review): allocation was issued by warp 1 but release is issued by
    // warp 0 -- confirm tmem_alloc/tmem_dealloc permit different warps.
    if (wid == 0) tmem_dealloc(tb, TMEM_N);
}
/**
 * Host driver: fills Q (T x HD) and K (SK x HD) with reproducible pseudo-random
 * bf16 values, runs test_qk_softmax_kernel on a single 128-thread block, and
 * compares the returned probabilities with a CPU softmax(Q K^T * SCALE).
 *
 * Returns 0 when every element is within tolerance, 1 otherwise. Any CUDA or
 * allocation failure prints a message and terminates with exit code 1.
 */
int main() {
    printf("QK+Softmax Test (HD=%d, SK=%d)\n", (int)HD, (int)SK);
    const int T = 1;
    const float SCALE = 1.0f / sqrtf((float)HD);
    const float TOL = 0.01f;

    // Abort on the first failing CUDA call so later results are never read
    // from a context that is already in an error state.
    auto cuda_ok = [](cudaError_t e, const char* what) {
        if (e != cudaSuccess) {
            printf("CUDA ERROR: %s\n", cudaGetErrorString(e));
            printf("  while: %s\n", what);
            exit(1);
        }
    };

    const size_t q_bytes = (size_t)T * HD * sizeof(bf16_t);
    const size_t k_bytes = (size_t)SK * HD * sizeof(bf16_t);
    const size_t out_bytes = (size_t)128 * SK * sizeof(float);

    bf16_t* h_q = (bf16_t*)calloc(T * HD, sizeof(bf16_t));
    bf16_t* h_k = (bf16_t*)calloc(SK * HD, sizeof(bf16_t));
    float* h_out = (float*)malloc(out_bytes);
    if (!h_q || !h_k || !h_out) { printf("host allocation failed\n"); return 1; }

    // Fixed seed keeps the inputs identical from run to run.
    srand(42);
    for (int i = 0; i < T * HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);
    for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);

    bf16_t *d_q = nullptr, *d_k = nullptr; float *d_out = nullptr;
    cuda_ok(cudaMalloc(&d_q, q_bytes), "cudaMalloc d_q");
    cuda_ok(cudaMalloc(&d_k, k_bytes), "cudaMalloc d_k");
    cuda_ok(cudaMalloc(&d_out, out_bytes), "cudaMalloc d_out");
    // Zero the output so an element the kernel never writes shows up as a
    // deterministic mismatch instead of whatever was left in the allocation.
    cuda_ok(cudaMemset(d_out, 0, out_bytes), "cudaMemset d_out");
    cuda_ok(cudaMemcpy(d_q, h_q, q_bytes, cudaMemcpyHostToDevice), "cudaMemcpy q");
    cuda_ok(cudaMemcpy(d_k, h_k, k_bytes, cudaMemcpyHostToDevice), "cudaMemcpy k");

    // The kernel takes the tensor map through a device pointer, so the
    // host-built descriptor is copied into global memory.
    CUtensorMap tma_k; CUtensorMap* d_tma_k = nullptr;
    create_tma_desc_2d_bf16(&tma_k, d_k, (uint64_t)SK, (uint64_t)HD, 128, 16);
    cuda_ok(cudaMalloc(&d_tma_k, sizeof(CUtensorMap)), "cudaMalloc tensor map");
    cuda_ok(cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice), "cudaMemcpy tensor map");

    // Upper bound on the kernel's shared-memory carve-up, including slack
    // for its alignment padding.
    size_t smem = 4 + 128 + TILE_SZ*sizeof(bf16_t) + TILE_SZ*sizeof(bf16_t) + TILE_SZ*sizeof(bf16_t) + 16 + 8 + 128*sizeof(float)*2 + 256;
    cuda_ok(cudaFuncSetAttribute(test_qk_softmax_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem),
            "cudaFuncSetAttribute");
    test_qk_softmax_kernel<<<1, 128, (int)smem>>>(d_out, d_q, d_tma_k, T, SK, SCALE);
    cuda_ok(cudaGetLastError(), "kernel launch");
    cuda_ok(cudaDeviceSynchronize(), "kernel execution");
    cuda_ok(cudaMemcpy(h_out, d_out, out_bytes, cudaMemcpyDeviceToHost), "cudaMemcpy out");

    // Reference: QK + softmax P values
    int fail = 0; float max_rel = 0;
    for (int t = 0; t < T; t++) {
        float s[SK], mx = -INFINITY;
        for (int j = 0; j < SK; j++) {
            float dot = 0;
            for (int d = 0; d < HD; d++) dot += bf16_to_f32_host(h_q[t*HD+d]) * bf16_to_f32_host(h_k[j*HD+d]);
            s[j] = dot * SCALE;
            mx = fmaxf(mx, s[j]);
        }
        float sm = 0;
        for (int j = 0; j < SK; j++) { s[j] = expf(s[j] - mx); sm += s[j]; }
        for (int j = 0; j < SK; j++) {
            float ref = s[j] / sm;
            float got = h_out[t * SK + j];
            // Relative error, falling back to absolute error for tiny references.
            float rel = fabsf(ref) > 1e-4f ? fabsf(got - ref) / fabsf(ref) : fabsf(got - ref);
            if (rel > max_rel) max_rel = rel;
            // Negated <= so that a NaN result counts as a failure; a plain
            // `rel > TOL` is false for NaN and would let it pass.
            const bool bad = !(rel <= TOL);
            if (bad && fail < 5) printf(" t=%d j=%d: ref=%.6f got=%.6f\n", t, j, ref, got);
            if (bad) fail++;
        }
    }
    printf("Max rel err: %.8f, failures: %d\n", max_rel, fail);
    printf("%s\n", fail == 0 ? "PASSED" : "FAILED");

    // Best-effort cleanup; the verdict above is already final.
    cudaFree(d_tma_k);
    cudaFree(d_out);
    cudaFree(d_k);
    cudaFree(d_q);
    free(h_out);
    free(h_k);
    free(h_q);
    return fail == 0 ? 0 : 1;
}