- Deleted `fmha.py` (CuTeDSL slow path), `FmhaKernel`, and the Python KV merge
- Deleted `fmha_sm100.cuh`, `fmha_sm100_tc.cuh`, `fmha_sm100_launch.cu`, `fmha_epilogue_sm100.cuh`
- Moved `fmha_qk_verify.cuh` to `tests/unit/qk_verify_kernel.cuh`
- Deleted `decode_sparse.py`, `decode_swa.py`, `kernels/decode/`
- Deleted 46 `test_d*.py` probes, `test_smem_*`, `test_cotiled_*`, `test_tmem_*`, `test_smem_p_*`, `test_ultra_minimal`, `test_fmha_pv16`, `test_working_softmax_maybe`
- Deleted root scratch files: `debug_linear.py`, `test_mapping.py`, `run_router_tests.py`
- Moved `archive/` to `archived_plans/code_archive/`
- Rewrote `production.py`: single fast path via the 6-warp multi-tile kernel
- Added `STATUS.md`, `audit_attention_live.md`
- Moved `NEXT_PRIORITIES*.md` to `archived_plans/`
118 lines
3.5 KiB
Plaintext
118 lines
3.5 KiB
Plaintext
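/**
 * DSV4 FMHA — QK GEMM verification with SWIZZLE_NONE UMMA layout.
 *
 * Debug kernel: compares a scalar Q·K reference dot product against the
 * result of a single UMMA (tcgen05) QK GEMM accumulated in TMEM.
 */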
|
||
#pragma once
|
||
|
||
#include "fmha_common.cuh"
|
||
#include "fmha_umma_desc.cuh"
|
||
#include <cstdint>
|
||
|
||
namespace dsv4::kernels::attention {

/**
 * fmha_qk_verify — single-tile QK^T check: scalar reference vs. one UMMA.
 *
 * Each CTA handles one (head = blockIdx.y, batch = blockIdx.z) pair:
 *   1. Stage a 128×HD Q tile and a 128×HD K tile (row-major, BF16) in dynamic
 *      shared memory. Only Q row 0 is real (Q[batch, head, :]); Q rows 1..127
 *      are zero. K rows [0, min(128, s_k)) come from global memory; the rest
 *      are zero.
 *   2. Thread 0 writes a scalar reference:
 *        s_out[0] = scale * dot(Q[0,:], K[0,:])
 *        s_out[1] = Q[0,0]
 *        s_out[2] = K[0,0]
 *   3. Warp 0 allocates 128 TMEM columns, zeroes them, and issues one
 *      umma_ss_f16 from hand-built SWIZZLE_NONE shared-memory descriptors.
 *   4. Warp 0 reads accumulator row 0, columns 0..3; lane 0 writes
 *        s_out[3..6] = scale * S[0, 0..3]
 *      so s_out[3] should match s_out[0] when the UMMA layout is right.
 *
 * Launch preconditions (derived from the code below):
 *   - blockDim.x == NTHREADS (the SMEM fill loop strides by NTHREADS).
 *   - Dynamic SMEM >= 256 + 2 * 128 * HD * sizeof(bf16_t) bytes.
 *   - s_out holds at least 7 floats.
 *   - Requires an SM100-class GPU (TMEM / tcgen05 UMMA).
 *
 * @param q          Q base; head h of batch b starts at q + b*bstride_q + h*HD.
 * @param k          K base; row r of batch b is at k + b*bstride_kv + r*HD
 *                   (one K row per position, shared by every head).
 * @param s_out      Output probe values (see step 2 and step 4).
 * @param bstride_q  Element stride between batches in q.
 * @param bstride_kv Element stride between batches in k.
 * @param s_k        Number of valid K rows; only the first 128 are used.
 * @param scale      Multiplier applied to both the reference and UMMA scores.
 *
 * NOTE(review): every CTA writes the same s_out[0..6]. Launching more than
 * one block (gridDim.y > 1 or gridDim.z > 1) is a write-write race, so the
 * kernel is presumably meant for a 1x1x1 grid — confirm with the caller.
 */
template<int HD>
__global__ void __launch_bounds__(NTHREADS)
fmha_qk_verify(
    const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
    float* __restrict__ s_out,
    int bstride_q, int bstride_kv,
    int s_k, float scale
) {
    constexpr int kTileRows = 128;  // rows staged for both Q and K
    constexpr int kTmemCols = 128;  // TMEM columns allocated for the accumulator

    const int head = blockIdx.y, batch = blockIdx.z, tid = threadIdx.x;
    const int wid = tid / WARP, lane = tid % WARP;

    // Widen before multiplying: batch * stride can exceed INT_MAX for large
    // batch/sequence sizes and would silently wrap in 32-bit arithmetic.
    const bf16_t* qh = q + static_cast<int64_t>(batch) * bstride_q
                         + static_cast<int64_t>(head) * HD;
    const bf16_t* kb = k + static_cast<int64_t>(batch) * bstride_kv;

    // SMEM layout (byte offsets into the dynamic buffer):
    //   [0..127]                     padding
    //   [128..131]                   tmem_base (TMEM address written by tmem_alloc)
    //   [132..255]                   padding
    //   [256 .. 256+128*HD*2)        sQ (128×HD BF16, row-major)
    //   [256+128*HD*2 .. +128*HD*2)  sK (128×HD BF16, row-major)
    // NOTE(review): UMMA descriptors need >=16-byte aligned addresses; this
    // assumes the dynamic SMEM base is at least 16-byte aligned.
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)(sbuf + 128);
    bf16_t* sQ = (bf16_t*)(sbuf + 256);
    bf16_t* sK = sQ + kTileRows * HD;

    // ---- Stage Q (row 0 only) and K (first kv_len rows) into SMEM ----------
    const int kv_len = min(kTileRows, s_k);
    for (int i = tid; i < kTileRows * HD; i += NTHREADS) {
        const int r = i / HD, c = i % HD;
        sQ[i] = (r == 0) ? qh[c] : 0;
        sK[i] = (r < kv_len) ? kb[r * HD + c] : 0;
    }
    __syncthreads();

    // ---- Scalar reference: S[0,0] computed directly from SMEM ---------------
    if (tid == 0) {
        float dot = 0.0f;
        for (int d = 0; d < HD; ++d) {
            dot += bf16_to_f32(sQ[d]) * bf16_to_f32(sK[d]);
        }
        s_out[0] = dot * scale;          // row 0, col 0 (scalar reference)
        s_out[1] = bf16_to_f32(sQ[0]);   // Q[0,0]
        s_out[2] = bf16_to_f32(sK[0]);   // K[0,0]
    }

    // ---- TMEM allocation (warp 0, whole warp) --------------------------------
    if (wid == 0) {
        const uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase);
        tmem_alloc(smem_ptr, kTmemCols);
    }
    __syncthreads();
    const uint32_t tmem_base = *sTmemBase;

    // ---- Zero the accumulator columns (warp 0, whole warp) ------------------
    if (wid == 0) {
        for (int col = 0; col < kTmemCols; ++col) {
            tmem_store(tmem_base + col, 0, 0, 0, 0);
        }
        tmem_fence_store();
    }
    __syncthreads();

    // ---- UMMA shared-memory descriptors (SWIZZLE_NONE) -----------------------
    // Bit packing (same for both operands):
    //   [13:0]  start address >> 4
    //   [29:16] leading-dimension byte offset, already in 16-byte units (LBO)
    //   [45:32] stride-dimension byte offset, already in 16-byte units (SBO)
    //   [46]    set to 1 (low bit of the fixed 0b001 field tcgen05 descriptors require)
    //   swizzle/layout bits left 0 => SWIZZLE_NONE
    auto make_smem_desc = [](uint32_t smem_addr, uint64_t lbo16, uint64_t sbo16) -> uint64_t {
        uint64_t desc = static_cast<uint64_t>(smem_addr >> 4) & 0x3FFF;
        desc |= (lbo16 & 0x3FFF) << 16;
        desc |= (sbo16 & 0x3FFF) << 32;
        desc |= static_cast<uint64_t>(1) << 46;
        return desc;
    };

    const uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
    const uint32_t sK_smem = __cvta_generic_to_shared(sK);

    // NOTE(review): these LBO/SBO constants were labelled "MN-major (128, 64)"
    // for Q and "K-major (128, 64)" for K, i.e. they look tuned for HD == 64.
    // sQ is actually stored row-major with HD contiguous, which is K-major for
    // an M×K A operand. The canonical SWIZZLE_NONE layout packs each 8-row ×
    // 16-byte core matrix contiguously, and plain row-major storage only
    // matches that when HD * sizeof(bf16_t) == 16. Verify against the
    // instruction descriptor umma_ss_f16 uses.
    const uint64_t desc_q = make_smem_desc(sQ_smem, 16, 128);  // LBO=16, SBO=128 (16-byte units)
    const uint64_t desc_k = make_smem_desc(sK_smem, 16, 32);   // LBO=16, SBO=32  (16-byte units)

    // ---- QK GEMM -------------------------------------------------------------
    // NOTE(review): tcgen05.mma is issued per thread. Unless umma_ss_f16
    // elects a single lane internally, all 32 lanes of warp 0 issue the
    // (non-accumulating, hence idempotent) MMA — confirm.
    if (wid == 0) {
        umma_ss_f16(tmem_base, desc_q, desc_k, false);
    }
    __syncwarp();
    // Called by the whole warp, matching the zeroing phase above. If
    // tmem_fence_store wraps a .sync.aligned wait, a lane-0-only call is UB.
    // NOTE(review): MMA completion is asynchronous. PTX requires
    // tcgen05.commit -> mbarrier wait before D is read with tcgen05.ld. If
    // umma_ss_f16 does not commit+wait internally, the readback below may
    // observe the zeroed accumulator.
    if (wid == 0) tmem_fence_store();
    __syncthreads();

    // ---- Read back S row 0, columns 0..3 ------------------------------------
    // tmem_load is executed by the whole warp: tcgen05.ld is a .sync.aligned
    // warp-collective instruction, so issuing it from a single lane is UB.
    // Lane 0 then publishes the values it received.
    // NOTE(review): assumes tmem_load waits for load completion (tcgen05.wait::ld)
    // before returning — confirm.
    if (wid == 0) {
        uint32_t u0, u1, u2, u3;
        tmem_load(tmem_base + 0, u0, u1, u2, u3);
        if (lane == 0) {
            s_out[3] = u32_to_f32(u0) * scale;  // MMA result: S[0,0]
            s_out[4] = u32_to_f32(u1) * scale;  // S[0,1]
            s_out[5] = u32_to_f32(u2) * scale;  // S[0,2]
            s_out[6] = u32_to_f32(u3) * scale;  // S[0,3]
        }
    }
    __syncthreads();

    // ---- Release TMEM (same warp that allocated it) --------------------------
    if (wid == 0) tmem_dealloc(tmem_base, kTmemCols);
}

} // namespace
|