Files
nvfp4-megamoe-kernel/tests/unit/test_softmax_pv.cu

175 lines
6.8 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* Test: Softmax → PV (1 K-tile) via TS MMA.
*
* Write S (random data) to TMEM, softmax reads it, writes P to TMEM,
* then TS MMA does PV for 1 K-tile.
* This isolates the softmax→PV path from QK.
*/
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; }
constexpr int SK = 128, HD = 16, BLOCK_MN = 128;
__global__ void __launch_bounds__(128)
test_softmax_pv(const float* __restrict__ s_input, // (SK,) — row 0 of S
const bf16_t* __restrict__ v, // (HD, SK)
float* __restrict__ o_mma,
float* __restrict__ o_ref,
float scale)
{
const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
bf16_t* sV = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
// Load V K-tile 0: (16, 16)
for (int i = tid; i < 256; i += 128) sV[i] = 0;
for (int d = tid; d < HD; d += 128) {
for (int lr = 0; lr < 16; lr++) {
int ck = d / 8, lc = d % 8;
int tmn = lr / 8, llr = lr % 8;
int dst_idx = ck * 2 * 64 + tmn * 64 + llr * 8 + lc;
sV[dst_idx] = v[d * SK + lr]; // First 16 positions of V
}
}
__syncthreads();
// TMEM alloc: 128 cols for S/P, 32 cols for O
if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 256);
__syncthreads();
uint32_t tb = *sTmemBase;
uint32_t tb_o = tb + 128;
// Write S row 0 to TMEM directly (bypass QK, use input data)
// All 128 columns of S: for decode T=1, only row 0 has data
if (wid == 0) {
float s_vals[SK];
// Read S from GMEM (already scaled)
if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] = s_input[j];
// Softmax
float row_max = -INFINITY;
if (lane == 0) for (int j=0;j<SK;j++) row_max = fmaxf(row_max, s_vals[j]);
row_max = wmax(row_max);
float row_sum = 0.0f;
if (lane == 0) for (int j=0;j<SK;j++) { s_vals[j] = expf(s_vals[j] - row_max); row_sum += s_vals[j]; }
row_sum = wsum(row_sum);
if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] /= row_sum;
// Write P to TMEM (128 columns) using 32x32b.x8
for (int n = 0; n < SK / 8; n++) {
float p0=(lane==0)?s_vals[n*8+0]:0, p1=(lane==0)?s_vals[n*8+1]:0;
float p2=(lane==0)?s_vals[n*8+2]:0, p3=(lane==0)?s_vals[n*8+3]:0;
float p4=(lane==0)?s_vals[n*8+4]:0, p5=(lane==0)?s_vals[n*8+5]:0;
float p6=(lane==0)?s_vals[n*8+6]:0, p7=(lane==0)?s_vals[n*8+7]:0;
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};"
:: "r"(tb+n*8),
"f"(p0),"f"(p1),"f"(p2),"f"(p3),"f"(p4),"f"(p5),"f"(p6),"f"(p7));
}
tmem_fence_store();
}
__syncthreads();
// PV GEMM: 1 K-tile
// A = P cols 0-15 (TMEM), B = V rows 0-15 (SMEM), C = O (TMEM at tb_o)
{
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
uint32_t idesc = make_idesc(128, 16);
if (tid == 0) umma_ts_f16(tb_o, tb, dv, idesc, false);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
}
// Read O row 0 from TMEM
if (wid == 0) {
float o_vals[HD];
for (int n = 0; n < HD / 8; n++) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(tb_o + n*8));
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * 2.0f;
}
if (lane == 0) for (int d=0;d<HD;d++) o_mma[d] = o_vals[d];
}
__syncthreads();
// Scalar reference: O = P × V (first 16 positions of V)
if (tid == 0) {
float s[SK];
for (int j=0;j<SK;j++) s[j] = s_input[j];
float mx = -INFINITY;
for (int j=0;j<SK;j++) mx = fmaxf(mx, s[j]);
float sm = 0.0f;
for (int j=0;j<SK;j++) { s[j] = expf(s[j]-mx); sm += s[j]; }
for (int j=0;j<SK;j++) s[j] /= sm;
for (int d=0;d<HD;d++) {
float ov = 0.0f;
for (int j=0;j<16;j++) ov += s[j] * bf16_to_f32(v[d*SK+j]); // Only first 16 positions
o_ref[d] = ov;
}
}
if (wid == 0) tmem_dealloc(tb, 256);
}
int main() {
printf("=== Softmax → PV (1 K-tile) ===\n");
const float SCALE = 1.0f / sqrtf((float)HD);
// Generate S = Q·K^T · scale (simulate QK output)
float* h_s = (float*)malloc(SK*sizeof(float));
bf16_t* h_v = (bf16_t*)malloc(HD*SK*sizeof(bf16_t));
float* h_o_mma = (float*)calloc(HD, sizeof(float));
float* h_o_ref = (float*)calloc(HD, sizeof(float));
srand(42);
// S values (simulate random attention scores)
for (int j=0;j<SK;j++) h_s[j] = ((float)(rand()%100)/100.0f - 0.5f) * 2.0f;
for (int i=0;i<HD*SK;i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);
float *d_s; bf16_t *d_v; float *d_o_mma, *d_o_ref;
cudaMalloc(&d_s, SK*sizeof(float));
cudaMalloc(&d_v, HD*SK*sizeof(bf16_t));
cudaMalloc(&d_o_mma, HD*sizeof(float));
cudaMalloc(&d_o_ref, HD*sizeof(float));
cudaMemcpy(d_s, h_s, SK*sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(d_v, h_v, HD*SK*sizeof(bf16_t), cudaMemcpyHostToDevice);
int smem = (4+16 + 256*2 + 256 + 127) & ~127;
test_softmax_pv<<<1, 128, smem>>>(d_s, d_v, d_o_mma, d_o_ref, SCALE);
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
cudaMemcpy(h_o_mma, d_o_mma, HD*sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(h_o_ref, d_o_ref, HD*sizeof(float), cudaMemcpyDeviceToHost);
printf("O[0..15] MMA: "); for(int d=0;d<HD;d++) printf("%.6f ",h_o_mma[d]); printf("\n");
printf("O[0..15] ref: "); for(int d=0;d<HD;d++) printf("%.6f ",h_o_ref[d]); printf("\n");
float cos_sim=0,na=0,nb=0;
for (int d=0;d<HD;d++) { float a=h_o_mma[d],b=h_o_ref[d]; cos_sim+=a*b; na+=a*a; nb+=b*b; }
cos_sim /= (sqrtf(na)*sqrtf(nb)+1e-10f);
printf("cosine: %.8f\n", cos_sim);
printf("Test %s\n", cos_sim > 0.9f ? "PASSED" : "FAILED");
cudaFree(d_s); cudaFree(d_v); cudaFree(d_o_mma); cudaFree(d_o_ref);
free(h_s); free(h_v); free(h_o_mma); free(h_o_ref);
return cos_sim > 0.9f ? 0 : 1;
}