// File: nvfp4-megamoe-kernel/tests/unit/qk_verify_kernel.cuh (118 lines, 3.5 KiB)

/**
* DSV4 FMHA — QK GEMM verification with SWIZZLE_NONE UMMA layout.
*/
#pragma once
#include "fmha_common.cuh"
#include "fmha_umma_desc.cuh"
#include <cstdint>
namespace dsv4::kernels::attention {

/**
 * Packs the 64-bit shared-memory matrix descriptor used by this test.
 *
 * Field placement (as written by this file):
 *   bits  0..13  smem_addr >> 4
 *   bits 16..29  lbo
 *   bits 32..45  sbo
 *   bit  46      always set
 *
 * NOTE(review): bit 46 is presumably the descriptor version field and the
 * untouched high bits presumably select the no-swizzle layout; confirm both
 * against fmha_umma_desc.cuh / the PTX descriptor format.
 *
 * @param smem_addr  Shared-window address of the tile (low 4 bits are dropped).
 * @param lbo        Value stored in the leading-offset field (14 bits kept).
 * @param sbo        Value stored in the stride-offset field (14 bits kept).
 * @return           The packed descriptor.
 */
__device__ __forceinline__ uint64_t qk_verify_make_smem_desc(
    uint32_t smem_addr, uint32_t lbo, uint32_t sbo) {
    constexpr uint64_t kFieldMask = 0x3FFF;
    uint64_t desc = 0;
    desc |= static_cast<uint64_t>(smem_addr >> 4) & kFieldMask;
    desc |= (static_cast<uint64_t>(lbo) & kFieldMask) << 16;
    desc |= (static_cast<uint64_t>(sbo) & kFieldMask) << 32;
    desc |= static_cast<uint64_t>(1) << 46;
    return desc;
}

/**
 * QK GEMM verification kernel: runs one UMMA on a Q tile and a K tile staged
 * in shared memory and writes the result next to a scalar reference.
 *
 * Launch layout: blockIdx.y selects the head, blockIdx.z the batch. The fill
 * loop strides by NTHREADS, so the block is assumed to have NTHREADS threads.
 *
 * Dynamic shared memory: at least 256 + 2 * 128 * HD * sizeof(bf16_t) bytes
 * (256-byte header holding the TMEM base at offset 128, then sQ, then sK).
 *
 * Output slots written to s_out:
 *   [0]     scalar dot(Q row 0, K row 0) * scale
 *   [1]     Q[0,0]
 *   [2]     K[0,0]
 *   [3..6]  the four 32-bit values returned by tmem_load for lane 0, * scale
 *
 * NOTE(review): every block writes the same s_out[0..6]; this looks intended
 * for a single-block launch — confirm with the launcher.
 *
 * @tparam HD         Head dimension (elements per Q/K row).
 * @param q           Query base pointer; row used is q + batch*bstride_q + head*HD.
 * @param k           Key base pointer; rows are read at stride HD from
 *                    k + batch*bstride_kv.
 * @param s_out       Output buffer, at least 7 floats.
 * @param bstride_q   Per-batch element stride of q.
 * @param bstride_kv  Per-batch element stride of k.
 * @param s_k         Number of valid key rows; at most 128 are loaded.
 * @param scale       Multiplier applied to the reported scores.
 */
template<int HD>
__global__ void __launch_bounds__(NTHREADS)
fmha_qk_verify(
    const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
    float* __restrict__ s_out,
    int bstride_q, int bstride_kv,
    int s_k, float scale
) {
    const int head = blockIdx.y, batch = blockIdx.z, tid = threadIdx.x;
    const int wid = tid / WARP, lane = tid % WARP;

    // 64-bit offsets: batch * stride can exceed the int range for large inputs.
    const bf16_t* qh = q + static_cast<int64_t>(batch) * bstride_q
                         + static_cast<int64_t>(head) * HD;
    const bf16_t* kb = k + static_cast<int64_t>(batch) * bstride_kv;

    // SMEM layout:
    //   [0..127]                 padding
    //   [128..131]               tmem_base
    //   [132..255]               padding
    //   [256..256+128*HD*2)      sQ (128×HD BF16 row-major)
    //   [256+128*HD*2..)         sK (128×HD BF16 row-major)
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)(sbuf + 128);
    bf16_t* sQ = (bf16_t*)(sbuf + 256);
    bf16_t* sK = sQ + 128 * HD;

    // Stage Q and K in shared memory. Only row 0 of sQ carries data; rows of
    // sK at or beyond kv_len are zero-filled.
    const int kv_len = min(128, s_k);
    for (int i = tid; i < 128 * HD; i += NTHREADS) {
        const int r = i / HD, c = i % HD;
        sQ[i] = (r == 0) ? qh[c] : 0;
        sK[i] = (r < kv_len) ? kb[r * HD + c] : 0;
    }
    // The tiles were written with ordinary stores but are consumed by the
    // tensor-core MMA, which reads shared memory through the async proxy.
    // Make the writes visible to that proxy before the barrier.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
#endif
    __syncthreads();

    // Sanity check: scalar QK dot product of row 0 against row 0.
    if (tid == 0) {
        float dot = 0.0f;
        for (int d = 0; d < HD; d++) {
            dot += bf16_to_f32(sQ[d]) * bf16_to_f32(sK[d]);
        }
        s_out[0] = dot * scale;          // row 0, col 0 (scalar reference)
        s_out[1] = bf16_to_f32(sQ[0]);   // Q[0,0]
        s_out[2] = bf16_to_f32(sK[0]);   // K[0,0]
    }

    // TMEM alloc: warp 0 requests 128 columns; the base is returned via SMEM.
    if (wid == 0) {
        const uint32_t smem_ptr =
            static_cast<uint32_t>(__cvta_generic_to_shared(sTmemBase));
        tmem_alloc(smem_ptr, 128);
    }
    __syncthreads();
    const uint32_t tmem_base = *sTmemBase;

    // Zero TMEM
    if (wid == 0) {
        for (int col = 0; col < 128; col++) {
            tmem_store(tmem_base + col, 0, 0, 0, 0);
        }
        tmem_fence_store();
    }
    __syncthreads();

    // UMMA descriptors (SWIZZLE_NONE).
    // NOTE(review): sQ/sK are filled row-major above; confirm that this
    // matches the no-swizzle layout the LBO/SBO values below describe.
    const uint32_t sQ_smem = static_cast<uint32_t>(__cvta_generic_to_shared(sQ));
    const uint32_t sK_smem = static_cast<uint32_t>(__cvta_generic_to_shared(sK));
    // Q operand: LBO=16, SBO=128
    const uint64_t desc_q = qk_verify_make_smem_desc(sQ_smem, 16, 128);
    // K operand: LBO=16, SBO=32
    const uint64_t desc_k = qk_verify_make_smem_desc(sK_smem, 16, 32);

    // QK GEMM
    // NOTE(review): a single umma_ss_f16 call is issued; confirm that the
    // helper covers the full HD reduction and elects the issuing lane.
    if (wid == 0) {
        umma_ss_f16(tmem_base, desc_q, desc_k, false);
    }
    __syncwarp();
    // Run the fence with the whole warp, as in the zeroing step above; the
    // underlying wait is presumably warp-collective.
    // NOTE(review): confirm that tmem_fence_store also waits for MMA
    // completion; otherwise a commit + mbarrier wait is needed here.
    if (wid == 0) tmem_fence_store();
    __syncthreads();

    // Read S row 0 from TMEM. The load is executed by all lanes of warp 0
    // (TMEM loads are warp-collective); only lane 0 reports its values.
    if (wid == 0) {
        uint32_t u0, u1, u2, u3;
        tmem_load(tmem_base + 0, u0, u1, u2, u3);
        if (lane == 0) {
            s_out[3] = u32_to_f32(u0) * scale;   // MMA result: S[0,0]
            s_out[4] = u32_to_f32(u1) * scale;   // S[0,1]
            s_out[5] = u32_to_f32(u2) * scale;   // S[0,2]
            s_out[6] = u32_to_f32(u3) * scale;   // S[0,3]
        }
    }
    __syncthreads();

    // Dealloc TMEM
    if (wid == 0) tmem_dealloc(tmem_base, 128);
}

} // namespace dsv4::kernels::attention