Files
nvfp4-megamoe-kernel/tests/unit/test_qk_pv_layout.cu

176 lines
7.0 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
 * Debug test: QK (SS MMA) → PV (TS MMA) with no softmax in between.
 * Checks whether the TS MMA can correctly consume, directly from TMEM, the S
 * matrix that the SS MMA wrote there, i.e. whether the TMEM layout of S is
 * compatible with the layout the TS MMA expects for its A operand.
 * If the output matches the scalar reference (cosine similarity > 0.999), the
 * layouts are compatible; if it is garbage, the TMEM layouts do not match.
 */
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
// Host-side fp32 -> bf16 conversion by truncation: keeps the upper 16 bits of
// the IEEE-754 single-precision encoding (no rounding).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits = 0;
    std::memcpy(&bits, &f, sizeof bits);
    const uint32_t upper_half = bits >> 16;
    return static_cast<uint16_t>(upper_half);
}
// Host-side bf16 -> fp32 widening: the 16 bf16 bits become the upper half of
// the fp32 encoding, with a zero lower half (exact conversion).
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float widened;
    std::memcpy(&widened, &bits, sizeof widened);
    return widened;
}
// Problem shape for this probe: one query row (q holds HD values) against SK
// keys of head dimension HD, using BLOCK_MN-row UMMA tiles.
constexpr int HD = 16, SK = 128, BLOCK_MN = 128;
// K-steps of the two GEMMs (MMA_K_BF16 elements per UMMA instruction).
constexpr int NKT_QK = HD / MMA_K_BF16; // 1
constexpr int NKT_PV = SK / MMA_K_BF16; // 8
// TMEM column plan: S starts at column 0, O starts at column TMEM_P, and
// TMEM_N columns are allocated in total.
constexpr int TMEM_P = 128;
constexpr int TMEM_N = 256;
// Elements in one BLOCK_MN x MMA_K_BF16 bf16 operand tile (used for Q and K).
constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16;

// The K-step counts above use integer division; reject shapes that would
// silently drop a partial K-tile.
static_assert(HD % MMA_K_BF16 == 0, "HD must be a multiple of MMA_K_BF16");
static_assert(SK % MMA_K_BF16 == 0, "SK must be a multiple of MMA_K_BF16");
// S (BLOCK_MN accumulator columns) must end before O begins, and O (HD
// columns) must fit inside the allocation.
static_assert(BLOCK_MN <= TMEM_P, "S accumulator columns overlap the O columns");
static_assert(TMEM_P + HD <= TMEM_N, "O accumulator does not fit in the TMEM allocation");
/**
 * Layout probe kernel for the QK (SS MMA) -> PV (TS MMA) chain.
 *
 * S = Q·K^T is computed by an SS UMMA into TMEM starting at column 0. Then
 * O = S·V is computed by a TS UMMA whose A operand is read straight from those
 * TMEM columns, with no softmax and no conversion. Thread 0 also computes a
 * scalar reference O = (Q·K^T · scale) × V.
 *
 * Launch: a single block of 128 threads (__launch_bounds__(128); blockIdx is
 * not used). Dynamic shared memory must cover a 4-byte TMEM-address slot, the
 * 16-byte aligned Q and K tiles (TILE_SZ each), and NKT_PV V tiles, which are
 * 128-byte aligned, plus alignment padding.
 *
 * @param q        Query row: HD bf16 values.
 * @param k        Keys, read as k[j * HD + d] for j < SK, d < HD.
 * @param v        Values, read as v[d * SK + j] for d < HD, j < SK.
 * @param o_mma    Out: HD floats read back from TMEM row 0 of the O accumulator.
 * @param o_scalar Out: HD floats from the scalar reference path.
 * @param scale    Applied to Q·K^T in the scalar reference only.
 */
__global__ void __launch_bounds__(128)
test_qk_pv_no_softmax(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
                      const bf16_t* __restrict__ v, float* __restrict__ o_mma,
                      float* __restrict__ o_scalar, float scale)
{
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    // Shared-memory carve-up: [0, 4) receives the TMEM base address from
    // tmem_alloc, followed by 16-byte aligned Q and K tiles, then 128-byte
    // aligned V tiles.
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sK0 = sQ0 + TILE_SZ;
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127);

    write_q_to_smem<HD>(sQ0, q);
    write_k_to_smem<SK, HD>(sK0, k);

    // Elements per V K-tile: HD (N) x MMA_K_BF16 (K). The host-side shared
    // memory size also budgets 256 elements per V tile.
    constexpr int V_TILE_ELEMS = 256;
    static_assert(V_TILE_ELEMS == HD * MMA_K_BF16,
                  "V tile packing below assumes HD * MMA_K_BF16 == 256");

    // Zero every V tile, barrier, and only then scatter V into them. The fill
    // is done only by threads d < HD (all in warp 0), while the zeroing is
    // spread over all 128 threads. Without the barrier, a lagging warp's zero
    // store could land after warp 0 has already written real V data.
    for (int i = tid; i < NKT_PV * V_TILE_ELEMS; i += 128) sV[i] = 0;
    __syncthreads();

    // Load V K-tiles. Within a tile, element (d, lr) goes to
    //   (d / 8) * 128 + (lr / 8) * 64 + (lr % 8) * 8 + (d % 8),
    // i.e. 64-element 8x8 blocks in which consecutive d are adjacent.
    // NOTE(review): d is the contiguous dimension here while the descriptor
    // below comes from make_umma_desc_kmajor_none — confirm the intended
    // major-ness against make_idesc / the descriptor helper.
    for (int kt = 0; kt < NKT_PV; kt++) {
        bf16_t* sv = sV + kt * V_TILE_ELEMS;
        for (int d = tid; d < HD; d += 128) {
            for (int lr = 0; lr < MMA_K_BF16; lr++) {
                int r = kt * MMA_K_BF16 + lr;
                int ck = d / 8, lc = d % 8;
                int tmn = lr / 8, llr = lr % 8;
                int dst_idx = ck * 2 * 64 + tmn * 64 + llr * 8 + lc;
                sv[dst_idx] = v[d * SK + r];
            }
        }
    }
    __syncthreads();

    // TMEM alloc: 256 cols. 0-127 = S. 128-143 = O.
    if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
    __syncthreads();
    uint32_t tb = *sTmemBase;
    uint32_t tb_o = tb + TMEM_P;

    // NOTE(review): no tcgen05.commit / mbarrier wait is visible in this
    // kernel. Unless umma_ss_f16 / umma_ts_f16 wait for MMA completion
    // internally, the tcgen05.ld of O further down may read an unfinished
    // accumulator — confirm against the helper definitions.

    // QK GEMM: S = Q·K^T into TMEM at tb, issued by a single thread.
    {
        uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), BLOCK_MN);
        uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), BLOCK_MN);
        uint32_t idesc_qk = make_idesc(BLOCK_MN, BLOCK_MN);
        for (int kt = 0; kt < NKT_QK; kt++) {
            if (tid == 0) umma_ss_f16(tb, dq, dk, idesc_qk, kt > 0);
            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
            __syncthreads();
        }
    }

    // NO SOFTMAX — use S directly as P (wrong mathematically, but tests TMEM layout)
    // PV GEMM: S × V → O. The TMEM A operand for K-step kt starts at column
    // tb + kt * MMA_K_BF16; the accumulate flag is false only for kt == 0.
    {
        uint32_t idesc_pv = make_idesc(BLOCK_MN, HD);
        for (int kt = 0; kt < NKT_PV; kt++) {
            uint32_t tmem_a = tb + kt * MMA_K_BF16;
            bf16_t* sv = sV + kt * V_TILE_ELEMS;
            uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sv), 16);
            if (tid == 0) umma_ts_f16(tb_o, tmem_a, dv, idesc_pv, kt > 0);
            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
            __syncthreads();
        }
    }

    // Read O from TMEM (row 0 only for T=1). Warp 0 loads 8 columns per
    // tcgen05.ld; lane 0 keeps its values.
    if (wid == 0) {
        float o_vals[HD];
        for (int n = 0; n < HD / 8; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                  "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                : "r"(tb_o + n*8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            // NOTE(review): the 2.0f factor is unexplained, and the MMA path
            // never applies `scale`, so printed MMA values are not on the
            // reference's scale. The cosine check ignores a positive overall
            // factor, so it does not affect pass/fail — confirm the intent.
            if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * 2.0f;
        }
        if (lane == 0) for (int d=0;d<HD;d++) o_mma[d] = o_vals[d];
    }
    __syncthreads();

    // Scalar reference: O = (Q·K^T · scale) × V (no softmax)
    if (tid == 0) {
        float s[SK];
        for (int j=0;j<SK;j++) {
            float dot = 0.0f;
            for (int d=0;d<HD;d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j*HD+d]);
            s[j] = dot * scale;
        }
        for (int d=0;d<HD;d++) {
            float ov = 0.0f;
            for (int j=0;j<SK;j++) ov += s[j] * bf16_to_f32(v[d*SK+j]);
            o_scalar[d] = ov;
        }
    }

    // NOTE(review): TMEM is allocated by warp 1 above but released by warp 0
    // here — confirm tcgen05.dealloc permits a different warp.
    if (wid == 0) tmem_dealloc(tb, TMEM_N);
}
/**
 * Host driver for test_qk_pv_no_softmax.
 *
 * Fills Q (HD values), K (SK x HD, read as k[j * HD + d]) and V (HD x SK, read
 * as v[d * SK + j]) with bf16 values from rand() seeded with 42. It then runs
 * the kernel on a single 128-thread block and prints the MMA and reference
 * output rows plus their cosine similarity. The test reports PASSED when the
 * cosine exceeds 0.999.
 *
 * Every CUDA call is checked, including the kernel launch via cudaGetLastError.
 * A failure prints "CUDA ERROR: ..." and the test exits with status 1. Host and
 * device buffers are released on every exit path.
 *
 * @return 0 on PASSED; 1 on FAILED or on any allocation/CUDA error.
 */
int main() {
    printf("=== QK → PV (no softmax) — TMEM layout test ===\n");
    const float SCALE = 1.0f / sqrtf((float)HD);

    bf16_t* h_q = (bf16_t*)malloc(HD*sizeof(bf16_t));
    bf16_t* h_k = (bf16_t*)malloc(SK*HD*sizeof(bf16_t));
    bf16_t* h_v = (bf16_t*)malloc(HD*SK*sizeof(bf16_t));
    float* h_o_mma = (float*)calloc(HD, sizeof(float));
    float* h_o_scalar = (float*)calloc(HD, sizeof(float));

    bf16_t *d_q = nullptr, *d_k = nullptr, *d_v = nullptr;
    float *d_o_mma = nullptr, *d_o_scalar = nullptr;
    int rc = 1;  // stays 1 unless the comparison below passes

    // Prints a failed CUDA status in the test's error format; true on success.
    auto ok = [](cudaError_t err) {
        if (err == cudaSuccess) return true;
        printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
        return false;
    };

    // Single-pass block: any `break` falls through to the shared cleanup.
    do {
        if (!h_q || !h_k || !h_v || !h_o_mma || !h_o_scalar) {
            printf("host allocation failed\n");
            break;
        }

        srand(42);
        for (int d=0;d<HD;d++) h_q[d] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
        for (int i=0;i<SK*HD;i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
        for (int i=0;i<HD*SK;i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);

        if (!ok(cudaMalloc(&d_q, HD*sizeof(bf16_t)))) break;
        if (!ok(cudaMalloc(&d_k, SK*HD*sizeof(bf16_t)))) break;
        if (!ok(cudaMalloc(&d_v, HD*SK*sizeof(bf16_t)))) break;
        if (!ok(cudaMalloc(&d_o_mma, HD*sizeof(float)))) break;
        if (!ok(cudaMalloc(&d_o_scalar, HD*sizeof(float)))) break;
        if (!ok(cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice))) break;
        if (!ok(cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice))) break;
        if (!ok(cudaMemcpy(d_v, h_v, HD*SK*sizeof(bf16_t), cudaMemcpyHostToDevice))) break;

        // Dynamic shared memory: TMEM-address slot + alignment slack, Q and K
        // tiles, NKT_PV V tiles of 256 bf16 each, and extra alignment slack,
        // rounded up to a multiple of 128 bytes.
        int smem = (4+16 + TILE_SZ*2 + TILE_SZ*2 + NKT_PV*256*2 + 256 + 127) & ~127;
        test_qk_pv_no_softmax<<<1, 128, smem>>>(d_q, d_k, d_v, d_o_mma, d_o_scalar, SCALE);
        // Launch-configuration errors are reported by cudaGetLastError();
        // faults during execution surface at the synchronize.
        if (!ok(cudaGetLastError())) break;
        if (!ok(cudaDeviceSynchronize())) break;

        if (!ok(cudaMemcpy(h_o_mma, d_o_mma, HD*sizeof(float), cudaMemcpyDeviceToHost))) break;
        if (!ok(cudaMemcpy(h_o_scalar, d_o_scalar, HD*sizeof(float), cudaMemcpyDeviceToHost))) break;

        printf("O[0..15] MMA: "); for(int d=0;d<HD;d++) printf("%.6f ",h_o_mma[d]); printf("\n");
        printf("O[0..15] ref: "); for(int d=0;d<HD;d++) printf("%.6f ",h_o_scalar[d]); printf("\n");

        // Cosine similarity is insensitive to a common positive scale, so it
        // compares output direction only. A NaN result fails the > check.
        float cos_sim=0,na=0,nb=0;
        for (int d=0;d<HD;d++) { float a=h_o_mma[d],b=h_o_scalar[d]; cos_sim+=a*b; na+=a*a; nb+=b*b; }
        cos_sim /= (sqrtf(na)*sqrtf(nb)+1e-10f);
        printf("cosine: %.8f\n", cos_sim);
        printf("Test %s\n", cos_sim > 0.999f ? "PASSED" : "FAILED");
        rc = cos_sim > 0.999f ? 0 : 1;
    } while (false);

    // cudaFree(nullptr) and free(nullptr) are no-ops, so partial setups are safe.
    cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o_mma); cudaFree(d_o_scalar);
    free(h_q); free(h_k); free(h_v); free(h_o_mma); free(h_o_scalar);
    return rc;
}