/**
 * DSV4 FMHA — QK GEMM verification with SWIZZLE_NONE UMMA layout.
 *
 * Debug kernel: computes a 128x128 tile of S = Q * K^T through UMMA/TMEM and,
 * alongside it, a scalar reference for S[0,0] so the two can be compared.
 */
#pragma once

#include "fmha_common.cuh"
#include "fmha_umma_desc.cuh"

#include <cstdint>

namespace dsv4::kernels::attention {

/**
 * fmha_qk_verify — one block per (head, batch): blockIdx.y = head,
 * blockIdx.z = batch.
 *
 * Template parameters:
 *   HD        head dimension (elements per Q/K row).
 *   NTHREADS  threads per block; used for __launch_bounds__ and as the SMEM
 *             fill stride, so it is assumed to equal blockDim.x. Warp 0 drives
 *             every TMEM/UMMA operation.
 *
 * Dynamic shared memory: at least 256 + 2 * 128 * HD * sizeof(bf16_t) bytes.
 *
 * Parameters:
 *   q           head h of batch b starts at q + b*bstride_q + h*HD; only its
 *               first row (HD elements) is used.
 *   k           batch b starts at k + b*bstride_kv; rows are HD elements apart.
 *   s_out       receives 7 floats:
 *                 [0]    scalar reference dot(Q[0,:], K[0,:]) * scale
 *                 [1]    Q[0,0] as staged in SMEM
 *                 [2]    K[0,0] as staged in SMEM
 *                 [3..6] UMMA result S[0,0..3] * scale read back from TMEM
 *   bstride_q   per-batch element stride of q.
 *   bstride_kv  per-batch element stride of k.
 *   s_k         number of valid K rows; the first min(128, s_k) rows are
 *               loaded and the rest of the 128-row tile is zero-filled.
 *   scale       multiplier applied to every reported S value.
 *
 * NOTE(review): every block writes the same s_out[0..6], so with more than one
 * block in the grid the outcome depends on which block writes last. Launch a
 * single-block grid when checking values.
 */
template <int HD, int NTHREADS>
__global__ void __launch_bounds__(NTHREADS) fmha_qk_verify(
    const bf16_t* __restrict__ q,
    const bf16_t* __restrict__ k,
    float* __restrict__ s_out,
    int bstride_q,
    int bstride_kv,
    int s_k,
    float scale
) {
    const int head = blockIdx.y, batch = blockIdx.z, tid = threadIdx.x;
    const int wid = tid / WARP, lane = tid % WARP;
    const bf16_t* qh = q + batch * bstride_q + head * HD;
    const bf16_t* kb = k + batch * bstride_kv;

    // SMEM layout (base 256B-aligned for UMMA):
    //   [0..127]                  padding
    //   [128..131]                tmem_base (written by tmem_alloc)
    //   [132..255]                padding
    //   [256 .. 256+128*HD*2)     sQ (128xHD BF16 row-major)
    //   [256+128*HD*2 ..)         sK (128xHD BF16 row-major)
    extern __shared__ __align__(256) char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)(sbuf + 128);
    bf16_t* sQ = (bf16_t*)(sbuf + 256);
    bf16_t* sK = sQ + 128 * HD;

    // Stage Q and K in SMEM. Only row 0 of the Q tile is real; rows 1..127 are
    // zero, so row 0 is the only non-trivial row of S.
    const int kv_len = min(128, s_k);
    for (int i = tid; i < 128 * HD; i += NTHREADS) {
        const int r = i / HD, c = i % HD;
        sQ[i] = (r == 0) ? qh[c] : 0;
        sK[i] = (r < kv_len) ? kb[r * HD + c] : 0;
    }
    __syncthreads();

    // Scalar reference for S[0,0], plus the first staged Q/K elements.
    if (tid == 0) {
        float dot = 0.0f;
        for (int d = 0; d < HD; ++d) {
            dot += bf16_to_f32(sQ[d]) * bf16_to_f32(sK[d]);
        }
        s_out[0] = dot * scale;          // row 0, col 0 (scalar reference)
        s_out[1] = bf16_to_f32(sQ[0]);   // Q[0,0]
        s_out[2] = bf16_to_f32(sK[0]);   // K[0,0]
    }

    // TMEM allocation (128 columns) by warp 0; the base address is published
    // through SMEM.
    if (wid == 0) {
        uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase);
        tmem_alloc(smem_ptr, 128);
    }
    __syncthreads();
    uint32_t tmem_base = *sTmemBase;

    // Zero the accumulator columns. Each tmem_store writes four 32-bit values.
    // Assuming it mirrors tmem_load below (u0..u3 = S[0,0..3], i.e. four
    // consecutive columns), step by 4. Stepping by 1 would write columns
    // 128..130, past the 128-column allocation.
    if (wid == 0) {
        for (int col = 0; col < 128; col += 4) {
            tmem_store(tmem_base + col, 0, 0, 0, 0);
        }
        tmem_fence_store();
    }
    __syncthreads();

    // UMMA shared-memory descriptors, built as:
    //   bits [0,14)   SMEM start address >> 4
    //   bits [16,30)  leading-dimension byte offset (LBO), 16-byte units
    //   bits [32,46)  stride-dimension byte offset (SBO), 16-byte units
    //   bit  46       set to 1 (presumably the SM100 descriptor version; confirm)
    // All other bits stay 0, including the swizzle/layout field, which this
    // kernel relies on to select SWIZZLE_NONE.
    // NOTE(review): with SWIZZLE_NONE the hardware expects each 8-row x 16-byte
    // core matrix to be stored contiguously. sQ/sK above are plain row-major,
    // so they probably need repacking before these LBO/SBO values describe
    // them. The LBO/SBO values are also fixed for a (128, 64) tile, i.e.
    // HD == 64.
    auto make_desc = [](uint32_t smem_addr, uint32_t lbo, uint32_t sbo) -> uint64_t {
        uint64_t d = 0;
        d |= static_cast<uint64_t>(smem_addr >> 4) & 0x3FFF;
        d |= (static_cast<uint64_t>(lbo) & 0x3FFF) << 16;
        d |= (static_cast<uint64_t>(sbo) & 0x3FFF) << 32;
        d |= static_cast<uint64_t>(1) << 46;
        return d;
    };
    uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
    uint32_t sK_smem = __cvta_generic_to_shared(sK);
    // Labelled "MN-major NONE" for (128, 64) BF16, although sQ is filled with
    // K contiguous. NOTE(review): confirm the intended major-ness.
    uint64_t desc_q = make_desc(sQ_smem, 16, 128);
    // K-major NONE for (128, 64) BF16.
    uint64_t desc_k = make_desc(sK_smem, 16, 32);

    // QK GEMM into TMEM (accumulate = false).
    // NOTE(review): per the PTX ISA, tcgen05.mma completes asynchronously and
    // is normally tracked with tcgen05.commit to an mbarrier plus a wait on it.
    // Nothing here does that; tmem_fence_store() presumably orders only
    // earlier tmem_store calls. Also confirm whether umma_ss_f16 elects a
    // single issuing lane, since all of warp 0 calls it.
    if (wid == 0) {
        umma_ss_f16(tmem_base, desc_q, desc_k, false);
    }
    __syncwarp();
    // Warp-wide, like the zero-fill above. If this fence is a .sync.aligned
    // tcgen05 op, issuing it from lane 0 alone is not permitted.
    if (wid == 0) tmem_fence_store();
    __syncthreads();

    // Read S[0, 0..3] from TMEM. All of warp 0 issues the load (as for the
    // other TMEM ops); lane 0's registers hold S row 0, and lane 0 reports them.
    if (wid == 0) {
        uint32_t u0, u1, u2, u3;
        tmem_load(tmem_base + 0, u0, u1, u2, u3);
        if (lane == 0) {
            s_out[3] = u32_to_f32(u0) * scale;   // MMA result: S[0,0]
            s_out[4] = u32_to_f32(u1) * scale;   // S[0,1]
            s_out[5] = u32_to_f32(u2) * scale;   // S[0,2]
            s_out[6] = u32_to_f32(u3) * scale;   // S[0,3]
        }
    }
    __syncthreads();

    // Release the TMEM columns allocated above (warp 0, same as the alloc).
    if (wid == 0) tmem_dealloc(tmem_base, 128);
}

}  // namespace dsv4::kernels::attention