230 lines
9.1 KiB
Plaintext
230 lines
9.1 KiB
Plaintext
/**
 * Test QK + softmax: load Q and K, compute the QK^T GEMM, apply a row-wise
 * softmax, and write the probabilities P to global memory.
 * P is verified against a host reference; the PV product is not computed.
 */
|
|
|
|
#include <cuda_runtime.h>
|
|
#include <cuda.h>
|
|
#include <cstdio>
|
|
#include <cmath>
|
|
#include <cstdlib>
|
|
#include <cstring>
|
|
|
|
#ifndef HD_VAL
|
|
#define HD_VAL 64
|
|
#endif
|
|
|
|
#include "dsv4/kernels/attention/fmha_common.cuh"
|
|
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
|
#include "dsv4/kernels/attention/fmha_tma.cuh"
|
|
|
|
using namespace dsv4::kernels::attention;
|
|
|
|
// Host-side float -> bf16: keep the high 16 bits of the binary32 encoding
// (plain truncation, no round-to-nearest-even).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    const uint16_t high_half = static_cast<uint16_t>(bits >> 16);
    return high_half;
}

// Host-side bf16 -> float: place the 16 stored bits in the high half of a
// binary32 word whose low half is zero (exact widening).
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float widened;
    memcpy(&widened, &bits, sizeof(widened));
    return widened;
}
|
|
|
|
// Compile-time problem / tile configuration for this single-CTA test.
constexpr int HD = HD_VAL;                      // Q/K feature dim (reduction length of QK^T)
constexpr int SK = 128;                         // number of keys = score columns read back from TMEM
constexpr int NKT = HD / MMA_K_BF16;            // MMA_K_BF16-wide K-steps covering HD
constexpr int BLOCK_MN = 128;                   // UMMA tile extent in M and N
constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16;  // bf16 elements in one BLOCK_MN x MMA_K_BF16 operand tile
constexpr int TMEM_N = (HD <= 128) ? 128 : 256; // TMEM columns allocated for the accumulator
constexpr int CORES_MN = 16;                    // 8-row core matrices along M/N
constexpr int NUM_READS = SK / 8;               // tcgen05.ld ...x8 reads (8 columns each) per row

// NKT uses floor division: a HD that is not a multiple of MMA_K_BF16 would
// silently drop the trailing dimensions from the dot products.
static_assert(HD % MMA_K_BF16 == 0, "HD_VAL must be a multiple of MMA_K_BF16");
// Keys are staged into a single BLOCK_MN-row operand tile and read back 8 columns at a time.
static_assert(SK % 8 == 0 && SK <= BLOCK_MN, "SK must be a multiple of 8 and <= BLOCK_MN");
// The operand layout indexes 8-row core matrices; CORES_MN must cover BLOCK_MN rows.
static_assert(CORES_MN * 8 == BLOCK_MN, "CORES_MN must equal BLOCK_MN / 8");
// Every score column read back must lie inside the TMEM allocation.
static_assert(TMEM_N >= SK, "TMEM allocation too small for SK score columns");
|
|
|
|
/**
 * QK^T + softmax test kernel (single CTA, no PV).
 *
 * Launch contract: grid = 1 block, block = exactly 128 threads (4 warps).
 * Dynamic shared memory (carved from `sbuf` below): a 4-byte TMEM base slot
 * padded to 128 B, three TILE_SZ bf16 tiles (Q operand, K operand, TMA staging)
 * and two 8-byte mbarriers. Uses tcgen05.* PTX and TMA, so it needs a target
 * that supports them (sm_100a-class).
 *
 * Per MMA_K_BF16-wide slice kt of the HD dimension:
 *   zero operand tiles -> TMA the K slice into staging -> scatter Q and K into
 *   the 8x8 core-matrix operand layout -> fence generic stores to the async
 *   proxy -> one thread issues the UMMA (accumulating when kt > 0) and
 *   tcgen05.commit -> every thread waits on that commit before tiles are reused.
 * Afterwards each active thread reads its accumulator row (row r == TMEM lane r),
 * applies `scale`, and computes a max-subtracted softmax over the first
 * min(s_k, SK) columns.
 *
 * @param out_p  (T, s_k) row-major fp32 output; only rows < BLOCK_MN and
 *               columns < SK are written.
 * @param q      (T, HD) row-major bf16 queries.
 * @param tma_k  device copy of the K tensor map; assumed (per the host setup)
 *               to load a 128-row x MMA_K_BF16-column box at (col, row)
 *               coordinates -- TODO confirm against create_tma_desc_2d_bf16.
 * @param T      number of query rows.
 * @param s_k    number of valid keys; also the row stride of out_p.
 * @param scale  multiplier applied to raw QK^T logits before softmax.
 */
__global__ void __launch_bounds__(128)
test_qk_softmax_kernel(
    float* __restrict__ out_p,  // (T, SK) — softmax P values
    const bf16_t* __restrict__ q,
    CUtensorMap* __restrict__ tma_k,
    int T, int s_k, float scale
) {
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    // Extents clamped to what one BLOCK_MN x BLOCK_MN accumulator tile can hold,
    // so oversize T / s_k can never index past the shared tiles or my_vals[].
    const int n_rows = min(T, BLOCK_MN);
    const int n_cols = min(s_k, min(SK, BLOCK_MN));

    // ---- Shared-memory carve-up (fits in the size the host reserves) ----
    extern __shared__ __align__(128) char sbuf[];
    size_t off = 0;
    uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off = 4;
    off = (off + 127) & ~(size_t)127;
    bf16_t* sQ0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
    bf16_t* sK0 = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
    bf16_t* sTmaBuf = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
    off = (off + 15) & ~(size_t)15;
    uint64_t* sTmaBar = (uint64_t*)(sbuf + off); off += 8;  // TMA K-slice completion
    uint64_t* sMmaBar = (uint64_t*)(sbuf + off); off += 8;  // UMMA completion (tcgen05.commit)
    (void)off;

    const uint32_t tma_bar = (uint32_t)__cvta_generic_to_shared(sTmaBar);
    const uint32_t mma_bar = (uint32_t)__cvta_generic_to_shared(sMmaBar);

    // Warp 1 owns TMEM: it allocates here and deallocates at the end.
    if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
    if (tid == 0) {
        tma_mbarrier_init(tma_bar, 1);
        tma_mbarrier_init(mma_bar, 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    // Order the tcgen05 allocation before other threads read the TMEM base address.
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    const uint32_t tb = *sTmemBase;

    int tma_phase = 0;
    int mma_phase = 0;

    const bool my_warp_active = (T <= 32) ? (wid == 0) : (wid < 4);
    const int my_row = wid * 32 + lane;
    const bool my_row_active = my_warp_active && (my_row < n_rows);
    // tcgen05.ld .32x32b: warp w may only touch TMEM lanes [32*(w%4), 32*(w%4)+31];
    // the lane index lives in address bits [31:16].
    const uint32_t t_warp = tb + ((uint32_t)((wid % 4) * 32) << 16);

    // ---- QK^T GEMM: S (TMEM, fp32) = Q * K^T, one MMA_K_BF16 slice per step ----
    const uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
    for (int kt = 0; kt < NKT; kt++) {
        // Zero both operand tiles (padding rows/keys stay 0). The previous
        // iteration's MMA has completed (mma_bar wait below), so this is safe.
        for (int i = tid; i < TILE_SZ; i += 128) { sQ0[i] = 0; sK0[i] = 0; }
        if (tid == 0) {
            tma_mbarrier_arrive_expect_tx(tma_bar, TILE_SZ * sizeof(bf16_t));
            tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_k, tma_bar,
                        kt * MMA_K_BF16, 0);
        }
        // Barrier: every zero-store must land before any data store from another
        // thread, since the scatters below write to permuted indices.
        __syncthreads();

        // Scatter Q[:, kt-slice] into 8x8 core matrices (8 rows x 8 contiguous
        // elements): MN-adjacent cores are 64 elements apart, K-adjacent cores
        // CORES_MN * 64 elements apart.
        for (int d = tid; d < n_rows * MMA_K_BF16; d += 128) {
            const int r = d / MMA_K_BF16, c = d % MMA_K_BF16;
            const int full_d = kt * MMA_K_BF16 + c;
            if (full_d < HD) {
                sQ0[(c / 8) * CORES_MN * 64 + (r / 8) * 64 + (r % 8) * 8 + (c % 8)] = q[r * HD + full_d];
            }
        }

        // Wait for the K slice (row-major [key][c] in staging), then scatter it the same way.
        tma_mbarrier_wait(tma_bar, tma_phase); tma_phase ^= 1;
        for (int i = tid; i < n_cols * MMA_K_BF16; i += 128) {
            const int r = i / MMA_K_BF16, c = i % MMA_K_BF16;
            sK0[(c / 8) * CORES_MN * 64 + (r / 8) * 64 + (r % 8) * 8 + (c % 8)] = sTmaBuf[i];
        }

        // Make these generic-proxy shared stores visible to the async proxy (UMMA reads).
        asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
        __syncthreads();

        if (tid == 0) {
            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
            const uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), BLOCK_MN);
            const uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), BLOCK_MN);
            umma_ss_f16(tb, dq, dk, idesc, kt > 0);
            // The MMA is asynchronous: signal mma_bar once it (and earlier tcgen05 ops) complete.
            asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                         :: "r"(mma_bar) : "memory");
        }
        // Nobody may rewrite sQ0/sK0 or read TMEM until the MMA has finished.
        tma_mbarrier_wait(mma_bar, mma_phase); mma_phase ^= 1;
    }
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");

    // ---- Read scaled logits for this thread's row from TMEM, tracking the row max ----
    float my_vals[SK];
    float my_row_max = -INFINITY;
    if (my_warp_active) {  // warp-uniform: tcgen05.ld is .sync.aligned
        for (int n = 0; n < NUM_READS; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                  "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                : "r"(t_warp + n * 8));
            asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
#pragma unroll
            for (int c = 0; c < 8; c++) {
                const int col = n * 8 + c;
                const float s = tmp[c] * scale;
                my_vals[col] = s;
                if (my_row_active && col < n_cols) my_row_max = fmaxf(my_row_max, s);
            }
        }
    }

    // ---- Release TMEM: order all TMEM reads before the barrier, then the owner frees it ----
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    if (wid == 1) {
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        tmem_dealloc(tb, TMEM_N);
    }

    // ---- Max-subtracted softmax; each thread writes its entire own row ----
    if (my_row_active) {
        float row_sum = 0.0f;
        for (int j = 0; j < n_cols; j++) {
            my_vals[j] = expf(my_vals[j] - my_row_max);
            row_sum += my_vals[j];
        }
        const float inv_rs = 1.0f / row_sum;
        float* out_row = out_p + (size_t)my_row * s_k;
        for (int j = 0; j < n_cols; j++) out_row[j] = my_vals[j] * inv_rs;
    }
}
|
|
|
|
/**
 * Host driver: builds random bf16 Q (T x HD) and K (SK x HD), runs
 * test_qk_softmax_kernel on one 128-thread block, and compares P against a
 * float host reference computed from the same bf16 inputs.
 * Returns 0 when every element is within 1% relative error (absolute error
 * when |ref| <= 1e-4), 1 otherwise or on any CUDA runtime failure.
 */
int main() {
    printf("QK+Softmax Test (HD=%d, SK=%d)\n", (int)HD, (int)SK);
    const int T = 1;
    const float SCALE = 1.0f / sqrtf((float)HD);
    const size_t out_elems = (size_t)BLOCK_MN * SK;  // room for every row the kernel can write

    // Terminate with status 1 if a CUDA runtime call failed.
    auto check = [](cudaError_t err, const char* what) {
        if (err != cudaSuccess) {
            printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
            fprintf(stderr, "  (failing call: %s)\n", what);
            exit(1);
        }
    };

    // Host inputs: values in [-0.5, 0.49] converted to bf16 by f32_to_bf16_host.
    bf16_t* h_q = (bf16_t*)calloc(T * HD, sizeof(bf16_t));
    bf16_t* h_k = (bf16_t*)calloc(SK * HD, sizeof(bf16_t));
    float* h_out = (float*)malloc(out_elems * sizeof(float));
    if (!h_q || !h_k || !h_out) { printf("host allocation failed\n"); return 1; }
    srand(42);
    for (int i = 0; i < T * HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);
    for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);

    bf16_t *d_q = nullptr, *d_k = nullptr; float* d_out = nullptr;
    check(cudaMalloc(&d_q, T * HD * sizeof(bf16_t)), "cudaMalloc(d_q)");
    check(cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)), "cudaMalloc(d_k)");
    check(cudaMalloc(&d_out, out_elems * sizeof(float)), "cudaMalloc(d_out)");
    // Deterministic contents so any element the kernel fails to write is caught as a mismatch.
    check(cudaMemset(d_out, 0, out_elems * sizeof(float)), "cudaMemset(d_out)");
    check(cudaMemcpy(d_q, h_q, T * HD * sizeof(bf16_t), cudaMemcpyHostToDevice), "cudaMemcpy(d_q)");
    check(cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice), "cudaMemcpy(d_k)");

    CUtensorMap tma_k; CUtensorMap* d_tma_k = nullptr;
    create_tma_desc_2d_bf16(&tma_k, d_k, (uint64_t)SK, (uint64_t)HD, 128, 16);
    check(cudaMalloc(&d_tma_k, sizeof(CUtensorMap)), "cudaMalloc(d_tma_k)");
    check(cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice), "cudaMemcpy(d_tma_k)");

    size_t smem = 4 + 128 + TILE_SZ*sizeof(bf16_t) + TILE_SZ*sizeof(bf16_t) + TILE_SZ*sizeof(bf16_t) + 16 + 8 + 128*sizeof(float)*2 + 256;
    check(cudaFuncSetAttribute(test_qk_softmax_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem),
          "cudaFuncSetAttribute");
    test_qk_softmax_kernel<<<1, 128, (int)smem>>>(d_out, d_q, d_tma_k, T, SK, SCALE);
    check(cudaGetLastError(), "kernel launch");        // launch-configuration errors
    check(cudaDeviceSynchronize(), "kernel execution"); // in-kernel faults

    check(cudaMemcpy(h_out, d_out, out_elems * sizeof(float), cudaMemcpyDeviceToHost), "cudaMemcpy(h_out)");

    // Reference: QK + softmax P values
    int fail = 0; float max_rel = 0;
    for (int t = 0; t < T; t++) {
        float s[SK], mx = -INFINITY;
        for (int j = 0; j < SK; j++) {
            float dot = 0;
            for (int d = 0; d < HD; d++) dot += bf16_to_f32_host(h_q[t*HD+d]) * bf16_to_f32_host(h_k[j*HD+d]);
            s[j] = dot * SCALE;
            mx = fmaxf(mx, s[j]);
        }
        float sm = 0;
        for (int j = 0; j < SK; j++) { s[j] = expf(s[j] - mx); sm += s[j]; }
        for (int j = 0; j < SK; j++) {
            float ref = s[j] / sm;
            float got = h_out[t * SK + j];
            float rel = fabsf(ref) > 1e-4f ? fabsf(got - ref) / fabsf(ref) : fabsf(got - ref);
            // Written as !(rel <= tol) so NaN/Inf outputs count as failures.
            const bool bad = !(rel <= 0.01f);
            if (rel > max_rel) max_rel = rel;
            if (bad && fail < 5) printf(" t=%d j=%d: ref=%.6f got=%.6f\n", t, j, ref, got);
            if (bad) fail++;
        }
    }
    printf("Max rel err: %.8f, failures: %d\n", max_rel, fail);
    printf("%s\n", fail == 0 ? "PASSED" : "FAILED");

    free(h_q); free(h_k); free(h_out);
    check(cudaFree(d_q), "cudaFree(d_q)");
    check(cudaFree(d_k), "cudaFree(d_k)");
    check(cudaFree(d_out), "cudaFree(d_out)");
    check(cudaFree(d_tma_k), "cudaFree(d_tma_k)");
    return fail == 0 ? 0 : 1;
}
|