From fd6a9b00aeeadafdc94808639aa7908097355f90 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 29 May 2026 19:10:08 +0000 Subject: [PATCH] =?UTF-8?q?test:=20QK=20+=20softmax=20=E2=80=94=20verify?= =?UTF-8?q?=20P=20values=20against=20reference?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_qk_softmax.cu | 227 ++++++++++++++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 tests/unit/test_qk_softmax.cu diff --git a/tests/unit/test_qk_softmax.cu b/tests/unit/test_qk_softmax.cu new file mode 100644 index 00000000..bb0e51ad --- /dev/null +++ b/tests/unit/test_qk_softmax.cu @@ -0,0 +1,227 @@ +/** + * Test QK + Softmax: load Q and K, compute QK GEMM, softmax, write P to GMEM. + * Verify P values against reference. No PV. + */ + +#include +#include +#include +#include +#include +#include + +#ifndef HD_VAL +#define HD_VAL 64 +#endif + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" +#include "dsv4/kernels/attention/fmha_tma.cuh" + +using namespace dsv4::kernels::attention; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = HD_VAL; +constexpr int SK = 128; +constexpr int NKT = HD / MMA_K_BF16; +constexpr int BLOCK_MN = 128; +constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16; +constexpr int TMEM_N = (HD <= 128) ? 128 : 256; +constexpr int CORES_MN = 16; +constexpr int NUM_READS = SK / 8; + +__global__ void __launch_bounds__(128) +test_qk_softmax_kernel( + float* __restrict__ out_p, // (T, SK) — softmax P values + const bf16_t* __restrict__ q, + CUtensorMap* __restrict__ tma_k, + int T, int s_k, float scale +) { + const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32; + + extern __shared__ __align__(128) char sbuf[]; + size_t off = 0; + uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off = 4; + off = (off + 127) & ~(size_t)127; + bf16_t* sQ0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + bf16_t* sK0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + bf16_t* sTmaBuf = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + off = (off + 15) & ~(size_t)15; + uint64_t* sMbar = (uint64_t*)(sbuf + off); off += 8; + float* sRowMax = (float*)(sbuf + off); off += 128 * sizeof(float); + float* sRowSum = (float*)(sbuf + off); off += 128 * sizeof(float); + + if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N); + if (tid == 0) { + tma_mbarrier_init((uint32_t)__cvta_generic_to_shared(sMbar), 1); + asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory"); + } + __syncthreads(); + uint32_t tb = *sTmemBase; + const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); + int phase = 0; + + const bool my_warp_active = (T <= 32) ? (wid == 0) : (wid < 4); + const int my_row = my_warp_active ? (wid * 32 + lane) : 0; + const bool my_row_active = my_warp_active && (my_row < T); + + // QK GEMM + { + uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN); + for (int kt = 0; kt < NKT; kt++) { + for (int i = tid; i < TILE_SZ; i += 128) sQ0[i] = 0; + for (int d = tid; d < T * MMA_K_BF16; d += 128) { + int r = d / MMA_K_BF16, c = d % MMA_K_BF16; + int full_d = kt * MMA_K_BF16 + c; + if (full_d < HD && r < T) { + int ck = c / 8, lc = c % 8, cm = r / 8, lr = r % 8; + sQ0[ck * CORES_MN * 64 + cm * 64 + lr * 8 + lc] = q[r * HD + full_d]; + } + } + __syncthreads(); + + if (wid == 0 && lane == 0) { + tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_k, mbar_addr, kt * MMA_K_BF16, 0); + tma_mbarrier_arrive_expect_tx(mbar_addr, TILE_SZ * sizeof(bf16_t)); + } + tma_mbarrier_wait(mbar_addr, phase); phase ^= 1; + __syncthreads(); + + for (int i = tid; i < TILE_SZ; i += 128) sK0[i] = 0; + for (int i = tid; i < s_k * MMA_K_BF16; i += 128) { + int r = i / MMA_K_BF16, c = i % MMA_K_BF16; + int ck = c / 8, lc = c % 8, tmn = r / 8, lr = r % 8; + sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = sTmaBuf[i]; + } + __syncthreads(); + + if (tid == 0) { + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), BLOCK_MN); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), BLOCK_MN); + umma_ss_f16(tb, dq, dk, idesc, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + } + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + + // Softmax: row max + float my_row_max = -INFINITY; + if (my_warp_active) { + for (int n = 0; n < NUM_READS; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (my_row_active) { + for (int c = 0; c < 8; c++) { + int col = n * 8 + c; + if (col < s_k) my_row_max = fmaxf(my_row_max, tmp[c] * scale); + } + } + } + } + if (my_row_active) sRowMax[my_row] = my_row_max; + __syncthreads(); + + // Softmax: exp + sum + P + float my_p_vals[SK]; + float my_row_sum = 0.0f; + if (my_warp_active) { + float rm = my_row_active ? sRowMax[my_row] : 0.0f; + for (int n = 0; n < NUM_READS; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (my_row_active) { + for (int c = 0; c < 8; c++) { + int col = n * 8 + c; + if (col < s_k) { + float p = expf(tmp[c] * scale - rm); + my_p_vals[col] = p; + my_row_sum += p; + } + } + } + } + } + if (my_row_active) sRowSum[my_row] = my_row_sum; + __syncthreads(); + + // Write P to GMEM + if (my_row_active) { + float inv_rs = 1.0f / my_row_sum; + for (int j = 0; j < s_k; j++) { + out_p[my_row * s_k + j] = my_p_vals[j] * inv_rs; + } + } + + if (wid == 0) tmem_dealloc(tb, TMEM_N); +} + +int main() { + printf("QK+Softmax Test (HD=%d, SK=%d)\n", (int)HD, (int)SK); + const int T = 4; + const float SCALE = 1.0f / sqrtf((float)HD); + + bf16_t* h_q = (bf16_t*)calloc(T * HD, sizeof(bf16_t)); + bf16_t* h_k = (bf16_t*)calloc(SK * HD, sizeof(bf16_t)); + srand(42); + for (int i = 0; i < T * HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + + bf16_t *d_q, *d_k; float *d_out; + cudaMalloc(&d_q, T * HD * sizeof(bf16_t)); + cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)); + cudaMalloc(&d_out, 128 * SK * sizeof(float)); + cudaMemcpy(d_q, h_q, T * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + + CUtensorMap tma_k; CUtensorMap* d_tma_k; + create_tma_desc_2d_bf16(&tma_k, d_k, (uint64_t)SK, (uint64_t)HD, 128, 16); + cudaMalloc(&d_tma_k, sizeof(CUtensorMap)); + cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + + size_t smem = 4 + 128 + TILE_SZ + TILE_SZ + TILE_SZ + 16 + 8 + 128*4*2 + 256; + cudaFuncSetAttribute(test_qk_softmax_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem); + test_qk_softmax_kernel<<<1, 128, (int)smem>>>(d_out, d_q, d_tma_k, T, SK, SCALE); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } + + float* h_out = (float*)malloc(128 * SK * sizeof(float)); + cudaMemcpy(h_out, d_out, 128 * SK * sizeof(float), cudaMemcpyDeviceToHost); + + // Reference: QK + softmax P values + int fail = 0; float max_rel = 0; + for (int t = 0; t < T; t++) { + float s[128], mx = -INFINITY; + for (int j = 0; j < SK; j++) { + float dot = 0; + for (int d = 0; d < HD; d++) dot += bf16_to_f32_host(h_q[t*HD+d]) * bf16_to_f32_host(h_k[j*HD+d]); + s[j] = dot * SCALE; + mx = fmaxf(mx, s[j]); + } + float sm = 0; + for (int j = 0; j < SK; j++) { s[j] = expf(s[j] - mx); sm += s[j]; } + for (int j = 0; j < SK; j++) { + float ref = s[j] / sm; + float got = h_out[t * SK + j]; + float rel = fabsf(ref) > 1e-4f ? fabsf(got - ref) / fabsf(ref) : fabsf(got - ref); + if (rel > max_rel) max_rel = rel; + if (rel > 0.01f && fail < 5) printf(" t=%d j=%d: ref=%.6f got=%.6f\n", t, j, ref, got); + if (rel > 0.01f) fail++; + } + } + printf("Max rel err: %.8f, failures: %d\n", max_rel, fail); + printf("%s\n", fail == 0 ? "PASSED" : "FAILED"); + return fail == 0 ? 0 : 1; +}