Files
nvfp4-megamoe-kernel/tests/unit/test_fmha_v4.cu

251 lines
10 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* Full FMHA HD=16, SK=128 — PV via SS MMA (SMEM-P approach)
*
* Pipeline: Q×K^T (SS) → softmax (TMEM→regs) → P×V (SS, per K-tile) → epilogue
*
* Key fix: write P one K-tile at a time into a (128,16) canonical buffer,
* instead of writing all of P to a (128,128) canonical buffer. This avoids the
* K-tile offset accumulation bug in the (128,128) layout.
*
* Only decode (T=1) is exercised: the softmax is computed for row 0 only,
* staged through shared memory (s_p_vals), and copied into row 0 of the
* (128,16) buffer for each K-tile; all other rows stay zero. Prefill (T>1)
* would need every row of P filled per K-tile and is not implemented here.
*/
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
// Host-side float -> bf16 conversion: keeps the upper 16 bits of the IEEE-754
// bit pattern (truncation toward zero, no rounding). Used only to build inputs.
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    const uint16_t upper = static_cast<uint16_t>(bits >> 16);
    return upper;
}
// Host-side bf16 -> float conversion: the bf16 bits become the upper half of
// the float bit pattern, and the lower 16 mantissa bits are zero.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float widened;
    memcpy(&widened, &bits, sizeof widened);
    return widened;
}
// Problem / tile geometry for this test.
constexpr int HD = 16;        // head dimension (length of the q row)
constexpr int SK = 128;       // number of keys / values
constexpr int BLOCK_MN = 128; // UMMA M extent (also N for Q x K^T)
constexpr int NKT_QK = HD / MMA_K_BF16; // UMMA K-steps for Q x K^T (1 when MMA_K_BF16 == 16)
constexpr int NKT_PV = SK / MMA_K_BF16; // UMMA K-steps for P x V   (8 when MMA_K_BF16 == 16)
constexpr int TMEM_N = 128;   // TMEM columns allocated for the accumulators
constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16; // elements in one (BLOCK_MN, MMA_K) SMEM tile
// The K-step counts are integer quotients; a remainder would silently drop data.
static_assert(HD % MMA_K_BF16 == 0, "HD must be a multiple of the UMMA K step");
static_assert(SK % MMA_K_BF16 == 0, "SK must be a multiple of the UMMA K step");
// The softmax reads SK fp32 score columns back from TMEM.
static_assert(TMEM_N >= SK, "TMEM allocation must hold SK score columns");
// Orders this thread's prior tcgen05 operations before a following thread barrier.
static __device__ __forceinline__ void test_tc05_fence_before_sync() {
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
}

// Orders this thread's subsequent tcgen05 operations after a preceding thread barrier.
static __device__ __forceinline__ void test_tc05_fence_after_sync() {
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}

// Block barrier that also orders tcgen05 (TMEM / UMMA) operations across it.
static __device__ __forceinline__ void test_tc05_block_sync() {
    test_tc05_fence_before_sync();
    __syncthreads();
    test_tc05_fence_after_sync();
}

// Makes this thread's ordinary (generic-proxy) shared-memory stores visible to
// async-proxy readers such as the UMMA operand fetch. Must precede the barrier
// that releases the MMA-issuing thread.
static __device__ __forceinline__ void test_fence_smem_to_async_proxy() {
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
}

// Warp-collective TMEM load (32x32b.x8 shape: one TMEM lane per thread):
// reads 8 consecutive 32-bit columns starting at taddr, then waits for them.
static __device__ __forceinline__ void test_tmem_ld_x8(uint32_t taddr, float (&out)[8]) {
    asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                 : "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]),
                   "=f"(out[4]), "=f"(out[5]), "=f"(out[6]), "=f"(out[7])
                 : "r"(taddr));
    asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
}

/**
 * Single-CTA FMHA smoke test (decode, T=1): S = Q*K^T via SS UMMA, softmax of
 * row 0, O = P*V via one SS UMMA per K-tile, plus a one-thread scalar reference.
 *
 * Launch: <<<1, 128, smem>>>. Warp 0 does the TMEM reads and softmax; warp 1
 * owns the TMEM allocation; thread 0 issues the UMMAs and the scalar reference.
 * Uses tcgen05 PTX, so it needs a GPU/arch that supports tcgen05.
 *
 * Dynamic shared memory, carved in this order (the host must size for it):
 *   4 B          TMEM base address written by tmem_alloc
 *   sQ0, sK0     TILE_SZ bf16 each, 16-byte aligned
 *   sPk          TILE_SZ bf16 P K-tile staging buffer, 128-byte aligned
 *   sV           NKT_PV tiles of 256 bf16, 128-byte aligned
 *   s_p_vals     SK floats holding softmax row 0
 *
 * @param q        query row, read as q[d] for d < HD
 * @param k        keys, read as k[j * HD + d]
 * @param v        values, read as v[d * SK + j]
 * @param o_out    HD bf16 outputs of the tensor-core path (row 0)
 * @param o_scalar HD float outputs of the scalar reference
 * @param scale    multiplier applied to Q.K^T before the softmax
 */
__global__ void __launch_bounds__(128)
test_fmha_smem_p(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
const bf16_t* __restrict__ v, bf16_t* __restrict__ o_out,
float* __restrict__ o_scalar, float scale)
{
    // The V / P fills below hard-code 16x16 tiles split into 8x8 core matrices,
    // and the QK step issues a single UMMA without advancing its descriptors.
    static_assert(HD == 16 && MMA_K_BF16 == 16, "fills assume 16x16 K-tiles");
    static_assert(NKT_QK == 1, "QK step issues exactly one UMMA");

    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;
    const int nthreads = blockDim.x;
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sK0 = sQ0 + TILE_SZ;
    // sPk: (128, 16) P K-tile buffer (canonical layout), reused for every K-tile.
    bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127);
    // sV: NKT_PV K-tiles of (16, 16).
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127);
    // s_p_vals: fp32 softmax row 0, placed right after sV.
    float* s_p_vals = (float*)(sV + NKT_PV * 256);

    // ===== Setup: operands into SMEM =====
    write_q_to_smem<HD>(sQ0, q);
    write_k_to_smem<SK, HD>(sK0, k);

    // V K-tiles. The (d, lr) -> dst mapping is a bijection onto each tile's 256
    // slots, so every slot is written exactly once. No zero-fill is needed; the
    // original zero-fill raced with these writes because there was no barrier.
    // NOTE(review): d (the PV "N" index) lands in the positions the P fill uses
    // for K (block stride 128, contiguous within a core-matrix row), and lr (the
    // reduction index) lands in the row positions. That looks transposed
    // relative to the K-major P tile. Whether it is correct depends on how
    // make_umma_desc_kmajor_none(..., 16) sets LBO/SBO; confirm against
    // fmha_umma_desc.cuh. The mapping is kept unchanged here.
    for (int e = tid; e < NKT_PV * 256; e += nthreads) {
        const int lr = e % MMA_K_BF16;         // row inside the K-tile (key index)
        const int d = (e / MMA_K_BF16) % HD;   // head-dim column
        const int kt = e / (MMA_K_BF16 * HD);  // K-tile index
        const int dst = (d / 8) * 128 + (lr / 8) * 64 + (lr % 8) * 8 + (d % 8);
        sV[kt * 256 + dst] = v[d * SK + kt * MMA_K_BF16 + lr];
    }
    // Rows 1..127 of P are zero for every K-tile (decode T=1), so clear the
    // staging buffer once here. Only the row-0 slots are rewritten per K-tile.
    for (int i = tid; i < TILE_SZ; i += nthreads) sPk[i] = 0;
    test_fence_smem_to_async_proxy();
    __syncthreads();

    // TMEM alloc (warp 1 owns the allocation and also frees it at the end).
    if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
    test_tc05_block_sync();
    const uint32_t tb = *sTmemBase;

    // NOTE(review): a barrier does not by itself wait for an asynchronous UMMA
    // to finish. Unless umma_ss_f16 commits to an mbarrier and waits internally,
    // the TMEM reads after each MMA (and the sPk rewrite in the next PV step)
    // may race with the in-flight MMA. Confirm in fmha_umma_desc.cuh.

    // ===== STEP 1: S = Q x K^T (SS) =====
    {
        const uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), BLOCK_MN);
        const uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), BLOCK_MN);
        const uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
        if (tid == 0) umma_ss_f16(tb, dq, dk, idesc, false);
        test_tc05_block_sync();
    }

    // ===== STEP 2: softmax of row 0 (TMEM lane 0 -> warp 0, lane 0) =====
    if (wid == 0) {
        float s_vals[SK];
        float row_max = -INFINITY;
        for (int n = 0; n < SK / 8; n++) {
            float tmp[8];
            test_tmem_ld_x8(tb + n * 8, tmp);
            if (lane == 0) {
                for (int c = 0; c < 8; c++) {
                    s_vals[n * 8 + c] = tmp[c] * scale;
                    row_max = fmaxf(row_max, s_vals[n * 8 + c]);
                }
            }
        }
        row_max = wmax(row_max);
        float row_sum = 0.0f;
        if (lane == 0) {
            for (int j = 0; j < SK; j++) {
                s_vals[j] = expf(s_vals[j] - row_max);
                row_sum += s_vals[j];
            }
        }
        row_sum = wsum(row_sum);
        // Normalized P goes to SMEM for the per-K-tile fills below.
        if (lane == 0) {
            for (int j = 0; j < SK; j++) s_p_vals[j] = s_vals[j] / row_sum;
        }
    }
    // Publishes s_p_vals. The tcgen05 fences also order warp 0's TMEM loads
    // before the PV MMA overwrites those same accumulator columns.
    test_tc05_block_sync();

    // ===== STEP 3: O = P x V (SS), one UMMA per K-tile =====
    {
        const uint32_t idesc = make_idesc(BLOCK_MN, HD);
        const uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), BLOCK_MN);
        for (int kt = 0; kt < NKT_PV; kt++) {
            // Row 0 of the (128,16) canonical tile: P[kt*16 + c], c = 0..15.
            if (tid < MMA_K_BF16) {
                const int c = tid;
                sPk[(c / 8) * 16 * 64 + (c % 8)] = f32_to_bf16(s_p_vals[kt * MMA_K_BF16 + c]);
            }
            test_fence_smem_to_async_proxy();
            __syncthreads();
            // A = sPk (128,16), B = V K-tile kt (16,16); accumulate after the first tile.
            const uint64_t dv_kt = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV + kt * 256), 16);
            if (tid == 0) umma_ss_f16(tb, dp, dv_kt, idesc, kt > 0);
            test_tc05_block_sync();
        }
    }

    // ===== STEP 4: epilogue, row 0 of O from TMEM =====
    if (wid == 0) {
        float o_vals[HD];
        for (int n = 0; n < HD / 8; n++) {
            float tmp[8];
            test_tmem_ld_x8(tb + n * 8, tmp);
            if (lane == 0) for (int c = 0; c < 8; c++) o_vals[n * 8 + c] = tmp[c];
        }
        if (lane == 0) for (int d = 0; d < HD; d++) o_out[d] = f32_to_bf16(o_vals[d]);
    }
    // All TMEM reads are complete and ordered before the deallocation below.
    test_tc05_fence_before_sync();
    __syncthreads();

    // Scalar reference (single thread, fp32 softmax weights).
    if (tid == 0) {
        float s[SK];
        for (int j = 0; j < SK; j++) {
            float dot = 0.0f;
            for (int d = 0; d < HD; d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * HD + d]);
            s[j] = dot * scale;
        }
        float mx = -INFINITY;
        for (int j = 0; j < SK; j++) mx = fmaxf(mx, s[j]);
        float sm = 0.0f;
        for (int j = 0; j < SK; j++) { s[j] = expf(s[j] - mx); sm += s[j]; }
        for (int j = 0; j < SK; j++) s[j] /= sm;
        for (int d = 0; d < HD; d++) {
            float ov = 0.0f;
            for (int j = 0; j < SK; j++) ov += s[j] * bf16_to_f32(v[d * SK + j]);
            o_scalar[d] = ov;
        }
    }

    // Free TMEM from the same warp that allocated it.
    if (wid == 1) {
        test_tc05_fence_after_sync();
        tmem_dealloc(tb, TMEM_N);
    }
}
/**
 * Host driver for test_fmha_smem_p.
 *
 * Fills Q (HD), K (SK x HD) and V (HD x SK) with bf16 values from a fixed
 * rand() seed, launches one 128-thread CTA, and compares the tensor-core output
 * (o_out) with the kernel's scalar reference (o_scalar).
 *
 * Pass criterion: cosine similarity > 0.999. Cosine is scale-invariant, so the
 * least-squares scale between the two outputs is reported separately (with a
 * warning when it is far from 1) rather than being divided out. The old
 * per-element ratio average could go negative on near-zero reference entries,
 * which flipped the cosine sign and failed correct results.
 *
 * @return 0 on pass; 1 on a failed comparison, CUDA error or host allocation error.
 */
int main() {
    printf("=== Full FMHA HD=16 SMEM-P (per K-tile fill) ===\n");
    const float SCALE = 1.0f / sqrtf((float)HD);

    // Reports a failed CUDA call and returns false, so calls can be chained with &&.
    auto cuda_ok = [](cudaError_t err, const char* what) -> bool {
        if (err == cudaSuccess) return true;
        printf("CUDA ERROR (%s): %s\n", what, cudaGetErrorString(err));
        return false;
    };

    const size_t q_bytes = HD * sizeof(bf16_t);
    const size_t kv_bytes = (size_t)SK * HD * sizeof(bf16_t);
    const size_t o_bytes = HD * sizeof(bf16_t);
    const size_t ref_bytes = HD * sizeof(float);

    bf16_t* h_q = (bf16_t*)malloc(q_bytes);
    bf16_t* h_k = (bf16_t*)malloc(kv_bytes);
    bf16_t* h_v = (bf16_t*)malloc(kv_bytes);
    bf16_t* h_o = (bf16_t*)calloc(HD, sizeof(bf16_t));
    float* h_o_scalar = (float*)calloc(HD, sizeof(float));
    bf16_t *d_q = nullptr, *d_k = nullptr, *d_v = nullptr, *d_o = nullptr;
    float* d_o_scalar = nullptr;

    // Frees everything. Safe on partially initialized state, because free and
    // cudaFree both accept null pointers.
    auto release = [&]() {
        cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_o_scalar);
        free(h_q); free(h_k); free(h_v); free(h_o); free(h_o_scalar);
    };

    if (!h_q || !h_k || !h_v || !h_o || !h_o_scalar) {
        printf("HOST ERROR: allocation failed\n");
        release();
        return 1;
    }

    srand(42);
    auto rand_bf16 = []() { return f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f); };
    for (int d = 0; d < HD; d++) h_q[d] = rand_bf16();
    for (int i = 0; i < SK * HD; i++) h_k[i] = rand_bf16();
    for (int i = 0; i < HD * SK; i++) h_v[i] = rand_bf16();

    // Dynamic SMEM must cover the kernel's carve-up, including worst-case
    // alignment padding: TMEM slot + pad(16) + sQ + sK + pad(128) + sPk +
    // pad(128) + sV + s_p_vals. The previous formula omitted sPk (4 KB), so
    // the kernel accessed shared memory past the allocation.
    const size_t smem_raw = sizeof(uint32_t) + 15
                          + 2 * (size_t)TILE_SZ * sizeof(bf16_t)
                          + 127 + (size_t)TILE_SZ * sizeof(bf16_t)
                          + 127 + (size_t)NKT_PV * 256 * sizeof(bf16_t)
                          + (size_t)SK * sizeof(float);
    const int smem = (int)((smem_raw + 127) & ~(size_t)127);
    printf("SMEM: %d bytes (%.1f KB)\n", smem, smem / 1024.0f);

    bool ok = cuda_ok(cudaMalloc(&d_q, q_bytes), "cudaMalloc q")
           && cuda_ok(cudaMalloc(&d_k, kv_bytes), "cudaMalloc k")
           && cuda_ok(cudaMalloc(&d_v, kv_bytes), "cudaMalloc v")
           && cuda_ok(cudaMalloc(&d_o, o_bytes), "cudaMalloc o")
           && cuda_ok(cudaMalloc(&d_o_scalar, ref_bytes), "cudaMalloc o_scalar")
           && cuda_ok(cudaMemcpy(d_q, h_q, q_bytes, cudaMemcpyHostToDevice), "copy q")
           && cuda_ok(cudaMemcpy(d_k, h_k, kv_bytes, cudaMemcpyHostToDevice), "copy k")
           && cuda_ok(cudaMemcpy(d_v, h_v, kv_bytes, cudaMemcpyHostToDevice), "copy v");
    if (ok) {
        test_fmha_smem_p<<<1, 128, smem>>>(d_q, d_k, d_v, d_o, d_o_scalar, SCALE);
        ok = cuda_ok(cudaGetLastError(), "kernel launch")
          && cuda_ok(cudaDeviceSynchronize(), "kernel execution")
          && cuda_ok(cudaMemcpy(h_o, d_o, o_bytes, cudaMemcpyDeviceToHost), "copy o")
          && cuda_ok(cudaMemcpy(h_o_scalar, d_o_scalar, ref_bytes, cudaMemcpyDeviceToHost), "copy o_scalar");
    }
    if (!ok) {
        release();
        return 1;
    }

    float o_mma[HD];
    for (int d = 0; d < HD; d++) o_mma[d] = bf16_to_f32_host(h_o[d]);
    printf("O[0..15] MMA: "); for (int d = 0; d < HD; d++) printf("%.6f ", o_mma[d]); printf("\n");
    printf("O[0..15] ref: "); for (int d = 0; d < HD; d++) printf("%.6f ", h_o_scalar[d]); printf("\n");

    double dot = 0.0, n_mma = 0.0, n_ref = 0.0, max_abs_err = 0.0;
    for (int d = 0; d < HD; d++) {
        const double a = o_mma[d], b = h_o_scalar[d];
        dot += a * b;
        n_mma += a * a;
        n_ref += b * b;
        max_abs_err = std::fmax(max_abs_err, std::fabs(a - b));
    }
    // Least-squares fit o_mma ~= s * o_ref. Near-zero reference entries do not
    // blow it up, unlike averaging per-element ratios.
    const double fit_scale = n_ref > 0.0 ? dot / n_ref : 0.0;
    const double cos_sim = dot / (std::sqrt(n_mma) * std::sqrt(n_ref) + 1e-10);
    printf("MMA/ref scale (least squares): %.6f\n", fit_scale);
    printf("max |MMA - ref|: %.3e\n", max_abs_err);
    printf("cosine = %.8f\n", cos_sim);
    if (std::fabs(fit_scale - 1.0) > 0.02)
        printf("WARNING: MMA output is scaled by %.4f relative to reference\n", fit_scale);

    const bool passed = cos_sim > 0.999;
    printf("Test %s\n", passed ? "PASSED" : "FAILED");
    release();
    return passed ? 0 : 1;
}