2026-05-28 08:00:42 +00:00
|
|
|
|
/**
|
2026-05-28 08:35:30 +00:00
|
|
|
|
* DSV4 FMHA — QK GEMM verification with SWIZZLE_NONE UMMA layout.
|
2026-05-28 08:00:42 +00:00
|
|
|
|
*/
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
|
|
#include "fmha_common.cuh"
|
|
|
|
|
|
#include "fmha_umma_desc.cuh"
|
2026-05-28 08:02:53 +00:00
|
|
|
|
#include <cstdint>
|
2026-05-28 08:00:42 +00:00
|
|
|
|
|
|
|
|
|
|
namespace dsv4::kernels::attention {
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * QK GEMM verification kernel (SWIZZLE_NONE UMMA descriptors).
 *
 * Stages one Q row and up to 128 K rows in shared memory, computes a scalar
 * reference dot product for element (0, 0), issues a single UMMA through
 * umma_ss_f16(), and writes both results to `s_out` for comparison.
 *
 * Launch layout: blockIdx.y selects the head and blockIdx.z the batch; the
 * staging loop strides by NTHREADS over threadIdx.x.
 *
 * Dynamic shared memory required: 256 + 2 * 128 * HD * sizeof(bf16_t) bytes.
 *
 * @tparam HD         Head dimension (elements per Q / K row).
 * @param q           Query data; the row used by this block starts at
 *                    batch * bstride_q + head * HD elements.
 * @param k           Key data; this block's rows start at batch * bstride_kv
 *                    elements and are read with a row pitch of HD elements.
 * @param s_out       Output, at least 7 floats:
 *                      [0]    scalar reference dot(Q row 0, K row 0) * scale
 *                      [1]    Q[0,0] as float
 *                      [2]    K[0,0] as float
 *                      [3..6] the four words loaded from TMEM at tmem_base,
 *                             converted with u32_to_f32() and multiplied by scale
 * @param bstride_q   Batch stride of `q`, in elements.
 * @param bstride_kv  Batch stride of `k`, in elements.
 * @param s_k         Number of valid K rows; rows >= min(128, s_k) are zero-filled.
 * @param scale       Multiplier applied to the reported dot products.
 *
 * NOTE(review): every block writes the same s_out[0..6]; presumably this
 * kernel is launched with a single block — confirm against the caller.
 */
template<int HD>
__global__ void __launch_bounds__(NTHREADS)
fmha_qk_verify(
    const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
    float* __restrict__ s_out,
    int bstride_q, int bstride_kv,
    int s_k, float scale
) {
    const int head = blockIdx.y, batch = blockIdx.z, tid = threadIdx.x;
    const int wid = tid / WARP, lane = tid % WARP;

    // Compute base offsets in 64-bit: batch * stride can exceed INT_MAX for
    // long sequences / large batches, and 32-bit signed overflow is undefined.
    const bf16_t* qh = q + static_cast<int64_t>(batch) * bstride_q
                         + static_cast<int64_t>(head) * HD;
    const bf16_t* kb = k + static_cast<int64_t>(batch) * bstride_kv;

    // SMEM layout (original note: 256B aligned for UMMA):
    //   [0..127]                 padding
    //   [128..131]               tmem_base
    //   [132..255]               padding
    //   [256..256+128*HD*2)      sQ (128 x HD BF16 row-major)
    //   [256+128*HD*2..)         sK (128 x HD BF16 row-major)
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)(sbuf + 128);
    bf16_t* sQ = (bf16_t*)(sbuf + 256);
    bf16_t* sK = sQ + 128 * HD;

    // Load Q and K to SMEM. Only row 0 of sQ carries data (the single query
    // row); all other sQ rows and K rows past kv_len are zero-filled.
    int kv_len = min(128, s_k);
    for (int i = tid; i < 128 * HD; i += NTHREADS) {
        int r = i / HD, c = i % HD;  // c is always in [0, HD)
        sQ[i] = (r == 0) ? qh[c] : 0;
        sK[i] = (r < kv_len) ? kb[r * HD + c] : 0;
    }
    __syncthreads();

    // Sanity check: scalar QK dot product of Q row 0 and K row 0.
    if (tid == 0) {
        float dot = 0;
        for (int d = 0; d < HD; d++) {
            dot += bf16_to_f32(sQ[d]) * bf16_to_f32(sK[d]);
        }
        s_out[0] = dot * scale;          // row 0, col 0 (scalar reference)
        s_out[1] = bf16_to_f32(sQ[0]);   // Q[0,0]
        s_out[2] = bf16_to_f32(sK[0]);   // K[0,0]
    }

    // TMEM alloc: warp 0 requests 128 columns; the resulting base is read
    // back from shared memory (sTmemBase) by every thread after the barrier.
    if (wid == 0) {
        uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase);
        tmem_alloc(smem_ptr, 128);
    }
    __syncthreads();
    uint32_t tmem_base = *sTmemBase;

    // Zero TMEM: warp 0 stores zeros to each of the 128 columns.
    if (wid == 0) {
        for (int col = 0; col < 128; col++) {
            tmem_store(tmem_base + col, 0, 0, 0, 0);
        }
        tmem_fence_store();
    }
    __syncthreads();

    // UMMA descriptors (SWIZZLE_NONE with proper strides).
    // Field packing used below, each field masked to 14 bits:
    //   bits [0,14)   shared-memory address >> 4
    //   bits [16,30)  LBO
    //   bits [32,46)  SBO
    //   bit  46       set to 1
    // NOTE(review): the stride constants are hard-coded for the (128, 64)
    // case named in the comments below — confirm before using another HD.
    uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
    uint32_t sK_smem = __cvta_generic_to_shared(sK);

    // MN-major NONE for (128, 64) BF16:
    // LBO=16 (uint128_t), SBO=128 (uint128_t)
    uint64_t desc_q = 0;
    desc_q |= (static_cast<uint64_t>(sQ_smem >> 4) & 0x3FFF);
    desc_q |= (static_cast<uint64_t>(16) & 0x3FFF) << 16;
    desc_q |= (static_cast<uint64_t>(128) & 0x3FFF) << 32;
    desc_q |= (static_cast<uint64_t>(1) << 46);

    // K-major NONE for (128, 64) BF16:
    // LBO=16, SBO=32
    uint64_t desc_k = 0;
    desc_k |= (static_cast<uint64_t>(sK_smem >> 4) & 0x3FFF);
    desc_k |= (static_cast<uint64_t>(16) & 0x3FFF) << 16;
    desc_k |= (static_cast<uint64_t>(32) & 0x3FFF) << 32;
    desc_k |= (static_cast<uint64_t>(1) << 46);

    // QK GEMM, issued from warp 0.
    if (wid == 0) {
        umma_ss_f16(tmem_base, desc_q, desc_k, false);
    }
    __syncwarp();
    // NOTE(review): the only ordering between the MMA and the TMEM read below
    // is tmem_fence_store() + __syncthreads(); confirm this is sufficient to
    // observe completion of umma_ss_f16().
    if (wid == 0 && lane == 0) tmem_fence_store();
    __syncthreads();

    // Read S row 0 from TMEM.
    // NOTE(review): tmem_load() is called by thread 0 only — confirm the
    // helper does not require the whole warp to participate.
    if (tid == 0) {
        uint32_t u0, u1, u2, u3;
        tmem_load(tmem_base + 0, u0, u1, u2, u3);
        s_out[3] = u32_to_f32(u0) * scale;  // MMA result: S[0,0]
        s_out[4] = u32_to_f32(u1) * scale;  // S[0,1]
        s_out[5] = u32_to_f32(u2) * scale;  // S[0,2]
        s_out[6] = u32_to_f32(u3) * scale;  // S[0,3]
    }
    __syncthreads();

    // Dealloc TMEM: same warp and column count as the allocation above.
    if (wid == 0) tmem_dealloc(tmem_base, 128);
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace
|