test: minimal QK — separate sQ0/sK0, clean SMEM layout
This commit is contained in:
165
tests/unit/test_qk_minimal.cu
Normal file
165
tests/unit/test_qk_minimal.cu
Normal file
@@ -0,0 +1,165 @@
|
||||
/**
|
||||
* Minimal QK test: load Q0 and K0 into SMEM, do one MMA, read result.
|
||||
* HD=64, NKT=4, T=1, SK=128.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda.h>
|
||||
#include <cstdio>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
#ifndef HD_VAL
|
||||
#define HD_VAL 64
|
||||
#endif
|
||||
|
||||
#include "dsv4/kernels/attention/fmha_common.cuh"
|
||||
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
||||
|
||||
using namespace dsv4::kernels::attention;
|
||||
|
||||
static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
|
||||
static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; }
|
||||
|
||||
constexpr int HD = HD_VAL;
|
||||
constexpr int SK = 128;
|
||||
constexpr int NKT = HD / MMA_K_BF16;
|
||||
constexpr int CORES_MN = 16; // 128/8
|
||||
|
||||
__global__ void __launch_bounds__(192)
|
||||
test_qk_minimal_kernel(float* __restrict__ out_s,
|
||||
const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, int T, int s_k)
|
||||
{
|
||||
static constexpr int TILE_SZ = 128 * MMA_K_BF16;
|
||||
static constexpr int TMEM_N = (HD <= 128) ? 128 : 256;
|
||||
static constexpr int NUM_READS = SK / 8;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
const int lane = tid % 32;
|
||||
const bool is_mma_warp = (wid == 4);
|
||||
|
||||
// SMEM: sQ0 and sK0 are (128, 16) each
|
||||
extern __shared__ __align__(128) char sbuf[];
|
||||
size_t off = 0;
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf; off = 4;
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sQ0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sK0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
|
||||
|
||||
if (is_mma_warp) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
for (int kt = 0; kt < NKT; kt++) {
|
||||
// Load Q sub-tile
|
||||
for (int i = tid; i < TILE_SZ; i += NTHREADS) sQ0[i] = 0;
|
||||
for (int r = tid / 32; r < T; r += 6) { // one row per warp
|
||||
if (r < T) {
|
||||
for (int d = lane; d < MMA_K_BF16; d += 32) {
|
||||
int full_d = kt * MMA_K_BF16 + d;
|
||||
if (full_d < HD) {
|
||||
int ck = d/8, lc = d%8, cm = r/8, lr = r%8;
|
||||
sQ0[ck*CORES_MN*64 + cm*64 + lr*8 + lc] = q[r * HD + full_d];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Load K sub-tile
|
||||
for (int i = tid; i < TILE_SZ; i += NTHREADS) sK0[i] = 0;
|
||||
for (int r = lane; r < s_k; r += 32) {
|
||||
for (int d = 0; d < MMA_K_BF16; d++) {
|
||||
int full_d = kt * MMA_K_BF16 + d;
|
||||
if (full_d < HD) {
|
||||
int ck = d/8, lc = d%8, cm = r/8, lr = r%8;
|
||||
sK0[ck*CORES_MN*64 + cm*64 + lr*8 + lc] = k[r * HD + full_d];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (is_mma_warp) {
|
||||
uint32_t idesc = make_idesc(128, 128);
|
||||
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128);
|
||||
if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// Read from TMEM
|
||||
if (wid == 0) {
|
||||
for (int n = 0; n < NUM_READS; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n * 8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane < T) {
|
||||
for (int c = 0; c < 8; c++) {
|
||||
int col = n * 8 + c;
|
||||
if (col < s_k) out_s[lane * s_k + col] = tmp[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (is_mma_warp) tmem_dealloc(tb, TMEM_N);
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("Minimal QK Test (HD=%d, SK=%d)\n", HD, SK);
|
||||
const int T = 1;
|
||||
|
||||
bf16_t* h_q = (bf16_t*)calloc(T * HD, sizeof(bf16_t));
|
||||
bf16_t* h_k = (bf16_t*)calloc(SK * HD, sizeof(bf16_t));
|
||||
srand(42);
|
||||
for (int i = 0; i < T * HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);
|
||||
for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);
|
||||
|
||||
bf16_t *d_q, *d_k; float *d_out;
|
||||
cudaMalloc(&d_q, T * HD * sizeof(bf16_t));
|
||||
cudaMalloc(&d_k, SK * HD * sizeof(bf16_t));
|
||||
cudaMalloc(&d_out, 128 * SK * sizeof(float));
|
||||
cudaMemcpy(d_q, h_q, T * HD * sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
|
||||
int smem = 4 + 128 + TILE_SZ*2 + 4096;
|
||||
test_qk_minimal_kernel<<<1, 192, smem>>>(d_out, d_q, d_k, T, SK);
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
|
||||
|
||||
float* h_out = (float*)malloc(128 * SK * sizeof(float));
|
||||
cudaMemcpy(h_out, d_out, 128 * SK * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
|
||||
float scale = 1.0f / sqrtf((float)HD);
|
||||
int fail = 0; float max_rel = 0;
|
||||
for (int t = 0; t < T; t++) {
|
||||
for (int j = 0; j < SK; j++) {
|
||||
float dot = 0;
|
||||
for (int d = 0; d < HD; d++)
|
||||
dot += bf16_to_f32_host(h_q[t * HD + d]) * bf16_to_f32_host(h_k[j * HD + d]);
|
||||
float ref = dot * scale;
|
||||
float got = h_out[t * SK + j];
|
||||
float rel = fabsf(ref) > 1e-4f ? fabsf(got - ref) / fabsf(ref) : fabsf(got - ref);
|
||||
if (rel > max_rel) max_rel = rel;
|
||||
if (rel > 0.01f && fail < 5) printf(" t=%d j=%d: ref=%.6f got=%.6f\n", t, j, ref, got);
|
||||
if (rel > 0.01f) fail++;
|
||||
}
|
||||
}
|
||||
printf("Max relative error: %.6f, failures: %d\n", max_rel, fail);
|
||||
printf("Raw output[0,0..4]: ");
|
||||
for (int j = 0; j < 5; j++) printf("%.6f ", h_out[j]);
|
||||
printf("\n");
|
||||
printf("%s\n", fail == 0 ? "PASSED" : "FAILED");
|
||||
return fail == 0 ? 0 : 1;
|
||||
}
|
||||
Reference in New Issue
Block a user