Files
nvfp4-megamoe-kernel/tests/unit/test_fmha_softmax.cu

215 lines
8.8 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
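/**
 * UMMA FMHA softmax test (HD=64, SK=128).
 * Validates, for row 0 only: QK GEMM -> read S from TMEM -> softmax -> write P to TMEM -> read P back,
 * and compares the result against a scalar softmax reference computed in the same kernel.
 * The PV GEMM is deferred to the next test.
 */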
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
// Host-side float -> bf16 conversion: copies the float's 32-bit pattern and
// keeps only its upper 16 bits (plain truncation, no rounding step).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    memcpy(&bits, &f, 4);
    const uint32_t upper_half = bits >> 16;
    return (uint16_t)upper_half;
}
// Host-side bf16 -> float conversion: places the 16-bit value in the upper
// half of a 32-bit word (lower half zero) and reinterprets it as a float.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t widened = (uint32_t)h << 16;
    float result;
    memcpy(&result, &widened, 4);
    return result;
}
// Problem shape used by this test: HD is the length of the q vector and the
// row stride of k, SK is the number of k rows, and NKT is the number of
// MMA_K_BF16-wide slices that cover HD.
// NOTE(review): MMA_K_BF16 comes from the included project headers; the kernel's
// 4-entry tile pointer arrays suggest it is 16 (NKT == 4) — confirm.
constexpr int HD = 64, SK = 128, NKT = HD / MMA_K_BF16;
// Shared-memory tile geometry: one tile holds BLOCK_MN * MMA_K_BF16 bf16
// elements; CORES_MN is the number of 8-row groups along the BLOCK_MN dimension.
constexpr int BLOCK_MN = 128, TILE_SZ = BLOCK_MN * MMA_K_BF16, CORES_MN = BLOCK_MN / 8;
/**
 * Test kernel: runs the QK GEMM into tensor memory (TMEM), computes the softmax
 * of row 0 of the result, writes the probabilities back to TMEM, reads them out
 * again, and also produces a scalar reference softmax for comparison on the host.
 *
 * Launch expectations (from the code below): a single block of 128 threads
 * (all cooperative loops stride by 128) and dynamic shared memory large enough
 * for a 4-byte TMEM-base slot followed by 16-byte-aligned Q tiles and K tiles
 * (NKT * TILE_SZ bf16 elements each). Uses tcgen05.* PTX instructions.
 *
 * @param q        query vector; indices [0, HD) are read
 * @param k        keys, read as k[r * HD + d] for r in [0, SK), d in [0, HD)
 * @param p_out    receives SK bf16 probabilities read back from TMEM (row 0)
 * @param p_scalar receives SK float probabilities from the scalar reference path
 * @param scale    factor multiplied into every raw dot product before softmax
 */
__global__ void __launch_bounds__(128)
test_softmax(const bf16_t* q, const bf16_t* k, bf16_t* p_out, float* p_scalar, float scale)
{
// Thread, warp and lane indices within the block.
const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;
extern __shared__ char sbuf[];
// First 4 bytes of dynamic shared memory: slot that receives the TMEM base address.
uint32_t* sTmemBase = (uint32_t*)sbuf;
// Q tiles start at the first 16-byte-aligned address after that slot; K tiles follow.
bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
bf16_t* sK0 = sQ0 + NKT * TILE_SZ;
// Load Q and K (same as working QK test)
// Zero both tile regions first so every element not written below stays 0.
for (int i = tid; i < NKT * TILE_SZ; i += 128) { sQ0[i] = 0; sK0[i] = 0; }
for (int kt = 0; kt < NKT; kt++) {
bf16_t* sq = sQ0 + kt * TILE_SZ;
// Q: only one row is populated. The index matches the K formula below with
// tmn == 0 and lr == 0, i.e. row 0 of the tile.
for (int d = tid; d < MMA_K_BF16; d += 128) {
int ck = d / 8, lc = d % 8;
sq[ck * CORES_MN * 64 + lc] = q[kt * MMA_K_BF16 + d];
}
bf16_t* sk = sK0 + kt * TILE_SZ;
// K: all SK rows, stored as 8x8 blocks of 64 elements; ck selects the group of
// 8 columns, tmn the group of 8 rows, (lr, lc) the position inside the block.
for (int r = 0; r < SK; r++) {
for (int d = tid; d < MMA_K_BF16; d += 128) {
int ck = d / 8, lc = d % 8;
int tmn = r / 8, lr = r % 8;
sk[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k[r * HD + kt * MMA_K_BF16 + d];
}
}
}
// Make the shared-memory tiles visible to the whole block.
__syncthreads();
// TMEM alloc
// Warp 1 calls tmem_alloc with the shared slot address and 128 (presumably the
// number of TMEM columns — confirm against the helper). After the barrier every
// thread reads the base address from the slot.
if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
__syncthreads();
uint32_t tb = *sTmemBase;
// QK GEMM
// Per-slice tile pointers.
// NOTE(review): these arrays have 4 entries and are indexed by kt < NKT, so
// this assumes NKT <= 4 — confirm MMA_K_BF16.
bf16_t* sQ_arr[4] = {sQ0, sQ0+TILE_SZ, sQ0+2*TILE_SZ, sQ0+3*TILE_SZ};
bf16_t* sK_arr[4] = {sK0, sK0+TILE_SZ, sK0+2*TILE_SZ, sK0+3*TILE_SZ};
uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
for (int kt = 0; kt < NKT; kt++) {
// Build shared-memory descriptors for this slice's Q and K tiles.
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ_arr[kt]), BLOCK_MN);
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK_arr[kt]), BLOCK_MN);
// A single thread issues the MMA targeting TMEM address tb. The last argument
// is false only for the first slice (presumably the accumulate flag — confirm).
if (tid == 0) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
}
// NOTE(review): no explicit MMA completion wait (commit / mbarrier) is visible
// in this kernel; presumably umma_ss_f16 handles it — confirm before relying on
// S being complete here.
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
// ================================================================
// SOFTMAX: Read row 0 of S, compute softmax, write P back to TMEM
// ================================================================
if (wid == 0) {
// Every lane of warp 0 declares the buffer, but only lane 0 fills and uses it.
float s_vals[SK];
float row_max = -INFINITY;
// Read S row 0 from TMEM using 32x32b.x8
// All 32 lanes execute the load (8 columns per step); only lane 0's values,
// i.e. row 0, are kept.
for (int n = 0; n < SK / 8; n++) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb + n * 8));
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) {
for (int c = 0; c < 8; c++) {
// S is UNSCALED raw dot product; apply scale
s_vals[n * 8 + c] = tmp[c] * scale;
row_max = fmaxf(row_max, tmp[c] * scale);
}
}
}
// Lanes 1-31 still hold -INFINITY here, so the result is lane 0's maximum
// (wmax is presumably a warp-wide max reduction — confirm).
row_max = wmax(row_max);
// exp(S - max) and sum
float row_sum = 0.0f;
if (lane == 0) {
for (int j = 0; j < SK; j++) {
s_vals[j] = expf(s_vals[j] - row_max);
row_sum += s_vals[j];
}
}
// Lanes 1-31 contribute 0 (wsum is presumably a warp-wide sum reduction — confirm).
row_sum = wsum(row_sum);
// Normalize
if (lane == 0) {
for (int j = 0; j < SK; j++) s_vals[j] /= row_sum;
}
// Write P back to TMEM using 32x32b.x8 stores
// P is (128, 128) with only row 0 non-zero.
// 32x32b.x8: 32 rows x 8 columns. Lane 0 writes row 0, lanes 1-31 write 0.
// The stores target the same TMEM addresses S was read from, so P overwrites S
// for the rows handled by this warp.
for (int n = 0; n < SK / 8; n++) {
float p0 = (lane == 0) ? s_vals[n*8+0] : 0;
float p1 = (lane == 0) ? s_vals[n*8+1] : 0;
float p2 = (lane == 0) ? s_vals[n*8+2] : 0;
float p3 = (lane == 0) ? s_vals[n*8+3] : 0;
float p4 = (lane == 0) ? s_vals[n*8+4] : 0;
float p5 = (lane == 0) ? s_vals[n*8+5] : 0;
float p6 = (lane == 0) ? s_vals[n*8+6] : 0;
float p7 = (lane == 0) ? s_vals[n*8+7] : 0;
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};" :: "r"(tb + n*8), "f"(p0), "f"(p1), "f"(p2), "f"(p3), "f"(p4), "f"(p5), "f"(p6), "f"(p7));
}
// Presumably waits for / orders the TMEM stores above — confirm against the helper.
tmem_fence_store();
}
__syncthreads();
// Read P back from TMEM to verify
if (wid == 0) {
float p_vals[SK];
for (int n = 0; n < SK / 8; n++) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb + n * 8));
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) {
for (int c = 0; c < 8; c++) p_vals[n * 8 + c] = tmp[c];
}
}
// Lane 0 converts row 0 to bf16 and writes it to global memory.
if (lane == 0) {
for (int j = 0; j < SK; j++) p_out[j] = f32_to_bf16(p_vals[j]);
}
}
__syncthreads();
// Scalar softmax reference
// Thread 0 recomputes the row-0 softmax directly from global q and k:
// dot products in float, scale, subtract the max, exponentiate, normalize.
if (tid == 0) {
float s[SK];
for (int j = 0; j < SK; j++) {
float dot = 0.0f;
for (int d = 0; d < HD; d++)
dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * HD + d]);
s[j] = dot * scale;
}
float mx = -INFINITY;
for (int j = 0; j < SK; j++) mx = fmaxf(mx, s[j]);
float sm = 0.0f;
for (int j = 0; j < SK; j++) { s[j] = expf(s[j] - mx); sm += s[j]; }
for (int j = 0; j < SK; j++) p_scalar[j] = s[j] / sm;
}
// Release the TMEM allocation made above, with the same size argument.
// NOTE(review): allocation was issued by warp 1 but deallocation is issued by
// warp 0 — confirm this is permitted for tmem_alloc/tmem_dealloc.
if (wid == 0) tmem_dealloc(tb, 128);
}
// Prints the CUDA error together with the failing step and terminates the test
// with exit code 1 when `err` is not cudaSuccess; otherwise does nothing.
static void softmax_test_cuda_check(cudaError_t err, const char* what) {
    if (err == cudaSuccess) return;
    printf("CUDA ERROR: %s (%s)\n", cudaGetErrorString(err), what);
    exit(1);
}

/**
 * Host driver for the softmax test.
 *
 * Fills q (HD values) and k (SK x HD values) with reproducible pseudo-random
 * bf16 data in [-0.5, 0.49], launches test_softmax with one block of 128
 * threads, and compares the probabilities read back from TMEM against the
 * scalar reference produced by the same kernel.
 *
 * @return 0 when the max error relative to the largest reference value is
 *         below 1%, the probabilities sum to 1 within 0.01 and every compared
 *         value is finite; 1 otherwise (including any CUDA or allocation failure).
 */
int main() {
    printf("=== UMMA FMHA Softmax HD=64 ===\n");
    const float SCALE = 1.0f / sqrtf((float)HD);

    // Host buffers; outputs are zero-initialized so a kernel that writes
    // nothing cannot accidentally pass.
    bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t));
    bf16_t* h_k = (bf16_t*)malloc(SK * HD * sizeof(bf16_t));
    bf16_t* h_p = (bf16_t*)calloc(SK, sizeof(bf16_t));
    float* h_p_scalar = (float*)calloc(SK, sizeof(float));
    if (!h_q || !h_k || !h_p || !h_p_scalar) {
        printf("Host allocation failed\n");
        free(h_q); free(h_k); free(h_p); free(h_p_scalar);
        return 1;
    }

    // Fixed seed keeps the inputs identical from run to run.
    srand(42);
    for (int d = 0; d < HD; d++) h_q[d] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);
    for (int i = 0; i < SK*HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);

    bf16_t *d_q = nullptr, *d_k = nullptr, *d_p = nullptr;
    float* d_p_scalar = nullptr;
    softmax_test_cuda_check(cudaMalloc(&d_q, HD*sizeof(bf16_t)), "cudaMalloc d_q");
    softmax_test_cuda_check(cudaMalloc(&d_k, SK*HD*sizeof(bf16_t)), "cudaMalloc d_k");
    softmax_test_cuda_check(cudaMalloc(&d_p, SK*sizeof(bf16_t)), "cudaMalloc d_p");
    softmax_test_cuda_check(cudaMalloc(&d_p_scalar, SK*sizeof(float)), "cudaMalloc d_p_scalar");
    softmax_test_cuda_check(cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice), "cudaMemcpy q H2D");
    softmax_test_cuda_check(cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice), "cudaMemcpy k H2D");

    // Dynamic shared memory: 4-byte TMEM-base slot, up to 16 bytes of alignment
    // padding, the Q and K tiles, plus 256 bytes of slack, rounded up to 128.
    const size_t smem_bytes = (4 + 16 + 2*NKT*TILE_SZ*sizeof(bf16_t) + 256 + 127) & ~(size_t)127;
    int smem = (int)smem_bytes;
    test_softmax<<<1, 128, smem>>>(d_q, d_k, d_p, d_p_scalar, SCALE);
    // Launch-configuration errors are only reported through cudaGetLastError();
    // faults during execution surface at the synchronize call.
    softmax_test_cuda_check(cudaGetLastError(), "test_softmax launch");
    softmax_test_cuda_check(cudaDeviceSynchronize(), "test_softmax execution");

    softmax_test_cuda_check(cudaMemcpy(h_p, d_p, SK*sizeof(bf16_t), cudaMemcpyDeviceToHost), "cudaMemcpy p D2H");
    softmax_test_cuda_check(cudaMemcpy(h_p_scalar, d_p_scalar, SK*sizeof(float), cudaMemcpyDeviceToHost), "cudaMemcpy p_scalar D2H");

    printf("P[0,0..7] MMA: "); for(int j=0;j<8;j++) printf("%.6f ",bf16_to_f32_host(h_p[j])); printf("\n");
    printf("P[0,0..7] ref: "); for(int j=0;j<8;j++) printf("%.6f ",h_p_scalar[j]); printf("\n");
    printf("P[0,64..71] MMA: "); for(int j=64;j<72;j++) printf("%.6f ",bf16_to_f32_host(h_p[j])); printf("\n");
    printf("P[0,64..71] ref: "); for(int j=64;j<72;j++) printf("%.6f ",h_p_scalar[j]); printf("\n");

    float max_diff = 0.0f, max_val = 0.0f;
    bool all_finite = true;
    for (int j = 0; j < SK; j++) {
        const float got = bf16_to_f32_host(h_p[j]);
        const float ref = h_p_scalar[j];
        // fmaxf returns the non-NaN operand, so a NaN would otherwise vanish
        // from max_diff; track non-finite values explicitly.
        if (!std::isfinite(got) || !std::isfinite(ref)) all_finite = false;
        float diff = fabsf(got - ref);
        max_diff = fmaxf(max_diff, diff);
        max_val = fmaxf(max_val, fabsf(ref));
    }
    float rel_err = max_val > 0 ? max_diff / max_val : max_diff;

    // Also check sum ~= 1.0
    float p_sum = 0.0f;
    for (int j = 0; j < SK; j++) p_sum += bf16_to_f32_host(h_p[j]);

    if (!all_finite) printf("Non-finite value found in MMA or reference probabilities\n");
    printf("Row 0 max rel err: %.8f | sum: %.6f\n", rel_err, p_sum);
    // Single source of truth for the verdict, used for both the message and the exit code.
    const bool passed = all_finite && rel_err < 0.01f && fabsf(p_sum - 1.0f) < 0.01f;
    printf("Test %s\n", passed ? "PASSED" : "FAILED");

    softmax_test_cuda_check(cudaFree(d_q), "cudaFree d_q");
    softmax_test_cuda_check(cudaFree(d_k), "cudaFree d_k");
    softmax_test_cuda_check(cudaFree(d_p), "cudaFree d_p");
    softmax_test_cuda_check(cudaFree(d_p_scalar), "cudaFree d_p_scalar");
    free(h_q); free(h_k); free(h_p); free(h_p_scalar);
    return passed ? 0 : 1;
}