diff --git a/tests/unit/test_qk_direct.cu b/tests/unit/test_qk_direct.cu new file mode 100644 index 00000000..edc27718 --- /dev/null +++ b/tests/unit/test_qk_direct.cu @@ -0,0 +1,162 @@ +/** + * Direct comparison: load Q and K via TMA AND via direct GMEM reads, + * do MMA with both, compare TMEM output. + */ + +#include +#include +#include +#include +#include +#include + +#ifndef HD_VAL +#define HD_VAL 64 +#endif + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" +#include "dsv4/kernels/attention/fmha_tma.cuh" + +using namespace dsv4::kernels::attention; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = HD_VAL; +constexpr int SK = 128; +constexpr int NKT = HD / MMA_K_BF16; + +// Two kernels: one with TMA, one with direct loads +// Both do QK GEMM → TMEM → read out + +// === DIRECT LOAD VERSION (known working pattern from fmha_6warp_multirow) === +__global__ void __launch_bounds__(192) +test_qk_direct_kernel(float* __restrict__ out_s, + const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, int T, int s_k) +{ + static constexpr int TILE_SZ = 128 * MMA_K_BF16; + static constexpr int TMEM_N = (HD <= 128) ? 128 : 256; + + const int tid = threadIdx.x; + const int wid = tid / 32; + const bool is_mma_warp = (wid == 4); + const bool is_softmax_warp = (wid < 4); + + extern __shared__ __align__(128) char sbuf[]; + size_t off = 0; + uint32_t* sTmemBase = (uint32_t*)sbuf; off = 4; + off = (off + 127) & ~(size_t)127; + bf16_t* sQ = (bf16_t*)(sbuf + off); off += 128 * HD * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + bf16_t* sK = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + + if (is_mma_warp) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N); + __syncthreads(); + uint32_t tb = *sTmemBase; + + // Direct load Q: row 0 only + write_q_to_smem(sQ, q); + __syncthreads(); + + for (int kt = 0; kt < NKT; kt++) { + // Load K sub-tile (128, 16) directly + for (int i = tid; i < TILE_SZ; i += NTHREADS) sK[i] = 0; + __syncthreads(); + // K in GMEM: (s_k, HD). Sub-tile at columns [kt*16, kt*16+16) + for (int i = tid; i < SK * MMA_K_BF16; i += NTHREADS) { + int r = i / MMA_K_BF16; + int c = i % MMA_K_BF16; + int gmem_c = kt * MMA_K_BF16 + c; + bf16_t val = k[r * HD + gmem_c]; + // Write to canonical + int core_mn = r / 8, core_k = c / 8; + int local_r = r % 8, local_c = c % 8; + int dst_idx = core_k * 16 * 64 + core_mn * 64 + local_r * 8 + local_c; + sK[dst_idx] = val; + } + __syncthreads(); + + if (is_mma_warp) { + uint32_t idesc = make_idesc(128, 128); + uint32_t sq_kt = (uint32_t)__cvta_generic_to_shared(sQ) + kt * 128 * 32; + uint64_t dq = make_umma_desc_kmajor_none(sq_kt, 128); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK), 128); + if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + + const bool my_warp_active = (wid == 0); + const int my_row = lane; + const bool my_row_active = my_row < T; + constexpr int NUM_READS = SK / 8; + + if (my_warp_active) { + for (int n = 0; n < NUM_READS; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (my_row_active) { + for (int c = 0; c < 8; c++) { + int col = n * 8 + c; + if (col < s_k) out_s[my_row * s_k + col] = tmp[c]; + } + } + } + } + __syncthreads(); + if (is_mma_warp) tmem_dealloc(tb, TMEM_N); +} + +int main() { + printf("Direct QK Test (HD=%d, SK=%d)\n", HD, SK); + const int T = 4; + + bf16_t* h_q = (bf16_t*)calloc(128 * HD, sizeof(bf16_t)); + bf16_t* h_k = (bf16_t*)calloc(SK * HD, sizeof(bf16_t)); + srand(42); + for (int i = 0; i < T * HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + + bf16_t *d_q, *d_k; float *d_out; + cudaMalloc(&d_q, 128 * HD * sizeof(bf16_t)); + cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)); + cudaMalloc(&d_out, 128 * SK * sizeof(float)); + cudaMemcpy(d_q, h_q, 128 * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + + int smem = 4 + 128*HD*2 + 128*16*2 + 4096; + test_qk_direct_kernel<<<1, 192, smem>>>(d_out, d_q, d_k, T, SK); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } + + float* h_out = (float*)malloc(128 * SK * sizeof(float)); + cudaMemcpy(h_out, d_out, 128 * SK * sizeof(float), cudaMemcpyDeviceToHost); + + float scale = 1.0f / sqrtf((float)HD); + int fail = 0; float max_rel = 0; + for (int t = 0; t < T; t++) { + for (int j = 0; j < SK; j++) { + float dot = 0; + for (int d = 0; d < HD; d++) + dot += bf16_to_f32_host(h_q[t * HD + d]) * bf16_to_f32_host(h_k[j * HD + d]); + float ref = dot * scale; + float got = h_out[t * SK + j]; + float rel = fabsf(ref) > 1e-4f ? fabsf(got - ref) / fabsf(ref) : fabsf(got - ref); + if (rel > max_rel) max_rel = rel; + if (rel > 0.01f && fail < 3) printf(" t=%d j=%d: ref=%.6f got=%.6f rel=%.4f\n", t, j, ref, got, rel); + if (rel > 0.01f) fail++; + } + } + printf("Max relative error: %.6f, failures: %d\n", max_rel, fail); + printf("%s\n", fail == 0 ? "PASSED" : "FAILED"); + return fail == 0 ? 0 : 1; +}