Files
nvfp4-megamoe-kernel/tests/unit/qk_verify_kernel.cuh
biondizzle 4b9eed02e1 Cleanup C1-C7: delete dead CuTeDSL FMHA, test probes, scratch files
- Deleted fmha.py (CuTeDSL slow path), FmhaKernel, Python KV merge
- Deleted fmha_sm100.cuh, fmha_sm100_tc.cuh, fmha_sm100_launch.cu, fmha_epilogue_sm100.cuh
- Moved fmha_qk_verify.cuh to tests/unit/qk_verify_kernel.cuh
- Deleted decode_sparse.py, decode_swa.py, kernels/decode/
- Deleted 46 test_d*.py probes, test_smem_*, test_cotiled_*, test_tmem_*,
  test_smem_p_*, test_ultra_minimal, test_fmha_pv16, test_working_softmax_maybe
- Deleted root scratch: debug_linear.py, test_mapping.py, run_router_tests.py
- Moved archive/ to archived_plans/code_archive/
- Rewrote production.py: single fast path via 6-warp multi-tile kernel
- Added STATUS.md, audit_attention_live.md
- Moved NEXT_PRIORITIES*.md to archived_plans/
2026-05-30 21:08:12 +00:00

118 lines
3.5 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
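/**
 * DSV4 FMHA — test-only verification of a QK^T GEMM issued as a tcgen05 UMMA
 * with SWIZZLE_NONE shared-memory descriptors, checked against a scalar
 * reference dot product computed in the same kernel.
 */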
#pragma once
#include "fmha_common.cuh"
#include "fmha_umma_desc.cuh"
#include <cstdint>
namespace dsv4::kernels::attention {

/**
 * Single-tile QK^T UMMA verification kernel (test only).
 *
 * Stages one Q row (row 0 of a 128-row tile; rows 1..127 are zero) and up to
 * 128 K rows into shared memory, computes a scalar reference dot product
 * Q[0,:]·K[0,:], then issues one UMMA with SWIZZLE_NONE descriptors into a
 * 128-column TMEM accumulator and reads back S[0,0..3].
 *
 * Output layout (s_out must hold at least 7 floats):
 *   [0]    scalar reference dot(Q[0,:], K[0,:]) * scale
 *   [1]    Q[0,0]
 *   [2]    K[0,0]
 *   [3..6] MMA result S[0,0..3] * scale
 *
 * Launch: head index from blockIdx.y, batch index from blockIdx.z,
 * NTHREADS threads per block, dynamic SMEM of at least
 * 256 + 2 * 128 * HD * sizeof(bf16_t) bytes (opt in via
 * cudaFuncAttributeMaxDynamicSharedMemorySize when above 48 KB).
 * NOTE(review): every CTA writes the same s_out[0..6], so launch a single CTA
 * when comparing results; otherwise the output is last-writer-wins.
 *
 * @param q          Q base; head `head` of batch `batch` starts at
 *                   q + batch*bstride_q + head*HD.
 * @param k          K base; batch `batch` starts at k + batch*bstride_kv, with
 *                   consecutive K rows HD elements apart (presumably a single
 *                   shared KV head — confirm against the caller).
 * @param s_out      Debug output buffer (layout above).
 * @param bstride_q  Batch stride of q, in elements.
 * @param bstride_kv Batch stride of k, in elements.
 * @param s_k        Number of valid K rows; only min(128, s_k) rows are loaded.
 * @param scale      Scale applied to the reported scores.
 */
template<int HD>
__global__ void __launch_bounds__(NTHREADS)
fmha_qk_verify(
    const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
    float* __restrict__ s_out,
    int bstride_q, int bstride_kv,
    int s_k, float scale
) {
    const int head = blockIdx.y, batch = blockIdx.z, tid = threadIdx.x;
    const int wid = tid / WARP, lane = tid % WARP;
    const bf16_t* qh = q + batch*bstride_q + head*HD;
    const bf16_t* kb = k + batch*bstride_kv;

    // SMEM layout (256B aligned for UMMA):
    //   [0..127]             padding
    //   [128..131]           tmem_base (written by tmem_alloc)
    //   [132..255]           padding
    //   [256..256+128*HD*2)  sQ (128×HD BF16 row-major)
    //   [256+128*HD*2..)     sK (128×HD BF16 row-major)
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)(sbuf + 128);
    bf16_t* sQ = (bf16_t*)(sbuf + 256);
    bf16_t* sK = sQ + 128 * HD;

    // Stage Q (row 0 only, other rows zero) and K (first kv_len rows, rest zero).
    int kv_len = min(128, s_k);
    for (int i = tid; i < 128 * HD; i += NTHREADS) {
        int r = i / HD, c = i % HD;  // c < HD by construction
        sQ[i] = (r == 0) ? qh[c] : 0;
        sK[i] = (r < kv_len) ? kb[r * HD + c] : 0;
    }
    __syncthreads();

    // Scalar reference: dot(Q[0,:], K[0,:]) plus the first element of each.
    if (tid == 0) {
        float dot = 0;
        for (int d = 0; d < HD; d++) {
            dot += bf16_to_f32(sQ[d]) * bf16_to_f32(sK[d]);
        }
        s_out[0] = dot * scale;           // row 0, col 0 (scalar reference)
        s_out[1] = bf16_to_f32(sQ[0]);    // Q[0,0]
        s_out[2] = bf16_to_f32(sK[0]);    // K[0,0]
    }

    // TMEM alloc (warp 0 allocates 128 columns; base address lands in SMEM).
    if (wid == 0) {
        uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase);
        tmem_alloc(smem_ptr, 128);
    }
    __syncthreads();
    uint32_t tmem_base = *sTmemBase;

    // Zero the accumulator. Assumes tmem_store writes 4 consecutive 32-bit
    // columns per call, matching the 4-value tmem_load below (S[0,0..3]).
    // Step by 4 so the last store ends at column 127 instead of running past
    // the 128 allocated columns.
    if (wid == 0) {
        for (int col = 0; col < 128; col += 4) {
            tmem_store(tmem_base + col, 0, 0, 0, 0);
        }
        tmem_fence_store();
    }
    __syncthreads();

    // UMMA descriptors (SWIZZLE_NONE with explicit strides).
    uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
    uint32_t sK_smem = __cvta_generic_to_shared(sK);
    // Q descriptor for (128, 64) BF16: LBO=16, SBO=128 (16-byte units).
    // NOTE(review): sQ is stored row-major with HD contiguous, i.e. K-major
    // for the A operand, although this descriptor was labelled "MN-major".
    // SWIZZLE_NONE also expects 8-row x 16-byte core matrices — confirm the
    // plain row-major staging above matches the LBO/SBO chosen here.
    uint64_t desc_q = 0;
    desc_q |= (static_cast<uint64_t>(sQ_smem >> 4) & 0x3FFF);   // start address
    desc_q |= (static_cast<uint64_t>(16) & 0x3FFF) << 16;       // LBO
    desc_q |= (static_cast<uint64_t>(128) & 0x3FFF) << 32;      // SBO
    desc_q |= (static_cast<uint64_t>(1) << 46);                 // presumably SM100 descriptor version bit
    // K-major NONE for (128, 64) BF16: LBO=16, SBO=32.
    uint64_t desc_k = 0;
    desc_k |= (static_cast<uint64_t>(sK_smem >> 4) & 0x3FFF);
    desc_k |= (static_cast<uint64_t>(16) & 0x3FFF) << 16;
    desc_k |= (static_cast<uint64_t>(32) & 0x3FFF) << 32;
    desc_k |= (static_cast<uint64_t>(1) << 46);

    // QK GEMM, issued from warp 0 (accumulate = false overwrites the tile).
    // NOTE(review): one BF16 UMMA presumably covers only K = 16, so HD > 16
    // needs HD/16 issues with advancing descriptors unless umma_ss_f16 loops
    // internally — confirm.
    // NOTE(review): the MMA completes asynchronously; there is no
    // tcgen05.commit + mbarrier wait here, so the load below may still observe
    // the zeroed accumulator.
    if (wid == 0) {
        umma_ss_f16(tmem_base, desc_q, desc_k, false);
    }
    __syncwarp();
    // The TMEM wait is warp-collective: issue it from the whole of warp 0,
    // as in the zeroing step above, not from lane 0 alone.
    if (wid == 0) tmem_fence_store();
    __syncthreads();

    // Read S row 0. The TMEM load is warp-collective (tcgen05.ld is
    // .sync.aligned), so all of warp 0 executes it. Lane 0 receives TMEM
    // lane 0 (= S row 0) and publishes the result.
    if (wid == 0) {
        uint32_t u0, u1, u2, u3;
        tmem_load(tmem_base + 0, u0, u1, u2, u3);
        if (lane == 0) {
            s_out[3] = u32_to_f32(u0) * scale;   // MMA result: S[0,0]
            s_out[4] = u32_to_f32(u1) * scale;   // S[0,1]
            s_out[5] = u32_to_f32(u2) * scale;   // S[0,2]
            s_out[6] = u32_to_f32(u3) * scale;   // S[0,3]
        }
    }
    __syncthreads();

    // Dealloc TMEM (same warp that allocated it).
    if (wid == 0) tmem_dealloc(tmem_base, 128);
}

} // namespace