test: UMMA QK HD=64 (4 K-tiles, accumulate) — multi-K-tile test

This commit is contained in:
2026-05-28 11:42:29 +00:00
parent df34cae9c6
commit 73f9ff98c9

View File

@@ -0,0 +1,231 @@
/**
* UMMA QK GEMM Test — HD=64 (4 K-tiles), SK=128
* Multi-K-tile accumulate: call MMA 4× with accumulate=true
* Each K-tile: 16 BF16 columns, separate descriptor
*/
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; }
// MMA K-tile = 16 BF16
constexpr int MMA_K = 16;
constexpr int N_WARPS = 4;
constexpr int BLOCK_M = 128;
__global__ void __launch_bounds__(N_WARPS * 32)
test_umma_qk_hd64(const bf16_t* q, const bf16_t* k,
float* s_out, float* s_scalar, float scale, int hd, int sk)
{
const int tid = threadIdx.x;
const int wid = tid / 32, lane = tid % 32;
const int n_ktiles = hd / MMA_K; // 4 for hd=64
// SMEM: sQ (128, HD) canonical + sK (128, HD) canonical
// Each K-tile of (128, 16) = 4096 bytes
// Full (128, HD) = n_ktiles * 4096 bytes
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
bf16_t* sK = sQ + 128 * hd;
// TMEM alloc (128 cols for N=128)
if (wid == 1) {
tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
}
__syncthreads();
uint32_t tb = *sTmemBase;
// Load Q and K into SMEM in canonical layout
// Using the template with HD as a parameter
// write_q_to_smem and write_k_to_smem need to work with hd=64
// For now, use explicit loops
// Zero all first
for (int i = tid; i < 128 * hd; i += N_WARPS * 32) {
sQ[i] = 0;
sK[i] = 0;
}
__syncthreads();
// Write Q (1, hd) to row 0 of sQ in canonical layout
// Canonical: core(g, c) at offset c * 16 * 64 + g * 64 + local_r * 8 + local_c
for (int d = tid; d < hd; d += N_WARPS * 32) {
int core_k = d / 8, local_c = d % 8;
int idx = core_k * 16 * 64 + local_c; // tile_mn=0, local_r=0
sQ[idx] = q[d];
}
// Write K (sk, hd) to sK in canonical layout
for (int i = tid; i < sk * hd; i += N_WARPS * 32) {
int r = i / hd, c = i % hd;
int tile_mn = r / 8, core_k = c / 8;
int local_r = r % 8, local_c = c % 8;
int idx = core_k * 16 * 64 + tile_mn * 64 + local_r * 8 + local_c;
sK[idx] = k[i];
}
__syncthreads();
// Construct base descriptors for Q and K
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
uint32_t sK_smem = __cvta_generic_to_shared(sK);
// For each K-tile, construct a descriptor pointing to that 16-column slice
// K-tile k: columns [16k, 16k+16) of the (128, hd) matrix
// In canonical layout, the k-th 16-column slice starts at:
// k * CORES_MN * 64 = k * 16 * 64 = k * 1024 BF16 = k * 2048 bytes
// Each K-tile has BLOCK_M=128 rows and 16 columns.
// The descriptor for K-tile k: start_addr = sQ_smem + k * 2048
// But wait — gau-nernst uses A_smem + k * BLOCK_M * 32 for the start address
// BLOCK_M * 32 = 128 * 32 = 4096 bytes. But our K-tile is only 2048 bytes.
// Actually, gau-nernst's offset is for the SMEM start of the K-tile.
// In his layout, each (BLOCK_M, 32B) slice is BLOCK_M * 32 = 4096 bytes apart.
// But 32B = 16 BF16 = one K-tile. And his start_address = A_smem + k * BLOCK_M * 32.
// Wait — that's 4096 bytes per K-tile, but our (128, 16) matrix is only 4096 bytes.
// In canonical layout, the K-tile at columns [16k, 16k+16) starts at:
// The 2 core-matrix columns (c=2k and c=2k+1) are at offsets 2k*2048 and (2k+1)*2048.
// The full K-tile spans both columns: core(0..15, 2k) and core(0..15, 2k+1).
// The first column starts at 2k * 2048 and the second at (2k+1) * 2048.
// Total span: from 2k*2048 to (2k+1)*2048 + 16*128 = (2k+1)*2048 + 2048 = (2k+2)*2048.
// Hmm, this is getting complicated. The descriptor for a K-tile should describe
// a (128, 16) matrix starting at the right offset in SMEM.
// For K-tile 0 (columns 0-15): start at sQ_smem, LBO=2048, SBO=128
// For K-tile 1 (columns 16-31): start at sQ_smem + 2*2048, LBO=2048, SBO=128
// Wait, but columns 16-23 are core_k=2 and columns 24-31 are core_k=3.
// The K-tile at columns [16k, 16k+16) has core_k = 2k and 2k+1.
// core_k=2 starts at 2 * 1024 = 2048 BF16 = 4096 bytes from sQ.
// core_k=3 starts at 3 * 1024 = 3072 BF16 = 6144 bytes from sQ.
// But the descriptor's start_address is for the BEGINNING of the (128, 16) tile.
// The descriptor with LBO=2048 walks: column 0 at start, column 1 at start+2048.
// So for K-tile k, we need:
// start = sQ_smem + (2k) * 1024 * 2 = sQ_smem + 2k * 2048 bytes
// Column 0 = core_k 2k, column 1 = core_k 2k+1
// LBO = 2048 (same as before)
// SBO = 128 (same as before)
// Actually, this IS the same descriptor but with a different start_address.
// The descriptor for K-tile k: start = sQ + 2k * 2048
uint32_t idesc = make_idesc(BLOCK_M, 128);
// K-tile loop with accumulate
for (int kt = 0; kt < n_ktiles; kt++) {
// Descriptor for Q's k-th K-tile
uint32_t q_ktile_addr = sQ_smem + kt * 2 * 2048; // 2 core-matrix columns per K-tile
uint32_t k_ktile_addr = sK_smem + kt * 2 * 2048;
uint64_t desc_q = make_umma_desc_kmajor_none(q_ktile_addr, BLOCK_M);
uint64_t desc_k = make_umma_desc_kmajor_none(k_ktile_addr, BLOCK_M);
bool accumulate = (kt > 0);
// 4 warp leaders call MMA
if (lane == 0) {
umma_ss_f16(tb, desc_q, desc_k, idesc, accumulate);
}
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
}
// Read from TMEM: 32x32b.x8, each warp reads 32 rows × 8 columns
for (int n = 0; n < 128 / 8; n++) {
const int row = wid * 32;
const int col = n * 8;
const int addr = tb + (row << 16) + col;
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];"
: "=f"(tmp[0]), "=f"(tmp[1]), "=f"(tmp[2]), "=f"(tmp[3]),
"=f"(tmp[4]), "=f"(tmp[5]), "=f"(tmp[6]), "=f"(tmp[7])
: "r"(addr));
asm volatile("tcgen05.wait::ld.sync.aligned;");
int out_row = wid * 32 + lane;
if (n < 2 && out_row < 128) { // First 16 cols
for (int c = 0; c < 8; c++) {
s_out[out_row * 16 + n * 8 + c] = tmp[c] * scale;
}
}
}
__syncthreads();
// Scalar reference
if (tid == 0) {
float* q_row = new float[hd];
for (int d = 0; d < hd; d++) q_row[d] = bf16_to_f32(q[d]);
for (int c = 0; c < sk; c++) {
float dot = 0.0f;
for (int d = 0; d < hd; d++)
dot += q_row[d] * bf16_to_f32(k[c * hd + d]);
s_scalar[c] = dot * scale;
}
delete[] q_row;
}
__syncthreads();
if (wid == 0) tmem_dealloc(tb, 128);
}
int main() {
printf("=== UMMA QK GEMM Test (HD=64, 4 K-tiles) ===\n");
const int HD = 64, SK = 128;
const float SCALE = 1.0f / sqrtf((float)HD);
bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t));
bf16_t* h_k = (bf16_t*)malloc(SK * HD * sizeof(bf16_t));
float* h_s_out = (float*)calloc(128 * 16, sizeof(float));
float* h_s_scalar = (float*)calloc(SK, sizeof(float));
srand(42);
for (int d = 0; d < HD; d++) h_q[d] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);
for (int i = 0; i < SK*HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);
bf16_t *d_q, *d_k; float *d_s_out, *d_s_scalar;
cudaMalloc(&d_q, HD*sizeof(bf16_t)); cudaMalloc(&d_k, SK*HD*sizeof(bf16_t));
cudaMalloc(&d_s_out, 128*16*sizeof(float)); cudaMalloc(&d_s_scalar, SK*sizeof(float));
cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
cudaMemset(d_s_out, 0, 128*16*sizeof(float)); cudaMemset(d_s_scalar, 0, SK*sizeof(float));
// SMEM: sTmemBase(4) + pad(16) + sQ(128*64*2=16384) + sK(128*64*2=16384) + pad
int smem = (4 + 16 + 128*HD*2 + 128*HD*2 + 256 + 127) & ~127;
printf("SMEM: %d bytes\n", smem);
test_umma_qk_hd64<<<1, N_WARPS*32, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE, HD, SK);
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
cudaMemcpy(h_s_out, d_s_out, 128*16*sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(h_s_scalar, d_s_scalar, SK*sizeof(float), cudaMemcpyDeviceToHost);
// Compare row 0
printf("S[0,0..7] MMA: ");
for (int c = 0; c < 8; c++) printf("%.6f ", h_s_out[0*16+c]);
printf("\nS[0,0..7] ref: ");
for (int c = 0; c < 8; c++) printf("%.6f ", h_s_scalar[c]);
printf("\n");
float max_diff = 0.0f, max_val = 0.0f;
for (int c = 0; c < 16; c++) {
max_diff = fmaxf(max_diff, fabsf(h_s_out[0*16+c] - h_s_scalar[c]));
max_val = fmaxf(max_val, fabsf(h_s_scalar[c]));
}
float rel_err = max_val > 0 ? max_diff / max_val : max_diff;
printf("Row 0 rel err (16 cols): %.6f\n", rel_err);
printf("Test %s\n", rel_err < 0.01f ? "PASSED" : "FAILED");
cudaFree(d_q); cudaFree(d_k); cudaFree(d_s_out); cudaFree(d_s_scalar);
free(h_q); free(h_k); free(h_s_out); free(h_s_scalar);
return 0;
}