test: HD=64 with separate SMEM per K-tile — no offset descriptors needed

This commit is contained in:
2026-05-28 12:18:06 +00:00
parent 526fafb808
commit 5a65d46c26

View File

@@ -1,7 +1,7 @@
/**
* UMMA QK GEMM Test — HD=64, SK=128, 1 K-tile (columns 0-15 only)
* Verifies the K-tile descriptor offset is correct by comparing
* against the HD=16 result (which is proven correct).
* UMMA QK GEMM Test — HD=64 (4 K-tiles), separate SMEM per K-tile.
* Each K-tile has its own (128, 16) SMEM region.
* This avoids the offset descriptor issue.
*/
#include <cuda_runtime.h>
@@ -18,64 +18,56 @@ using namespace dsv4::kernels::attention;
static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; }
constexpr int NKT = 4; // hd=64 / 16
__global__ void __launch_bounds__(128)
test_umma_qk_hd64_1ktile(const bf16_t* q, const bf16_t* k,
float* s_out, float* s_scalar, float scale, int hd, int sk)
test_umma_hd64(const bf16_t* q, const bf16_t* k,
float* s_out, float* s_scalar, float scale)
{
const int tid = threadIdx.x;
const int wid = tid / 32, lane = tid % 32;
// Separate SMEM per K-tile: sQ[4][128*16] and sK[4][128*16]
// Each (128, 16) = 4096 bytes in canonical layout
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
bf16_t* sK = sQ + 128 * hd;
// 4 Q K-tiles + 4 K K-tiles, each (128, 16) BF16 = 4096 bytes
bf16_t* sQ_base = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
bf16_t* sK_base = sQ_base + NKT * 128 * 16;
// TMEM alloc
if (wid == 0) {
tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
}
if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
__syncthreads();
uint32_t tb = *sTmemBase;
// Zero SMEM
for (int i = tid; i < 128 * hd; i += 128) { sQ[i] = 0; sK[i] = 0; }
__syncthreads();
// Load each K-tile separately
for (int kt = 0; kt < NKT; kt++) {
bf16_t* sQ = sQ_base + kt * 128 * 16;
bf16_t* sK = sK_base + kt * 128 * 16;
// Write Q (1, hd) to sQ row 0 in canonical layout
for (int d = tid; d < hd; d += 128) {
int ck = d / 8, lc = d % 8;
sQ[ck * 16 * 64 + lc] = q[d];
}
// Write K (sk, hd) to sK in canonical layout
for (int i = tid; i < sk * hd; i += 128) {
int r = i / hd, c = i % hd;
int tmn = r / 8, ck = c / 8, lr = r % 8, lc = c % 8;
sK[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k[i];
}
__syncthreads();
// Zero this K-tile
for (int i = tid; i < 128 * 16; i += 128) { sQ[i] = 0; sK[i] = 0; }
__syncthreads();
// Verify Q[0..7] in SMEM
if (tid == 0) {
for (int d = 0; d < 8; d++) s_out[200+d] = bf16_to_f32(sQ[d]);
for (int d = 0; d < 8; d++) s_out[208+d] = bf16_to_f32(sQ[1024+d]); // Q[8..15]
}
__syncthreads();
// Write Q's K-tile: row 0, columns [16*kt, 16*kt+16)
for (int d = tid; d < 16; d += 128) {
int ck = d / 8, lc = d % 8;
sQ[ck * 16 * 64 + lc] = q[kt * 16 + d];
}
// Write K's K-tile: all rows, columns [16*kt, 16*kt+16)
for (int i = tid; i < 128 * 16; i += 128) {
int r = i / 16, c = i % 16;
int tmn = r / 8, ck = c / 8, lr = r % 8, lc = c % 8;
sK[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k[r * 64 + kt * 16 + c];
}
__syncthreads();
// Descriptors
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
uint32_t sK_smem = __cvta_generic_to_shared(sK);
uint32_t idesc = make_idesc(128, 128);
// K-tile loop with accumulate
for (int kt = 0; kt < hd / 16; kt++) { // Full K-tile loop
// K-tile kt: columns [16*kt, 16*kt+16)
// In canonical layout, columns start at core_k = 2*kt and 2*kt+1
// Offset = 2*kt * 2048 bytes from matrix base
uint32_t q_kt = sQ_smem + kt * 4096; // 2 core cols * 2048 bytes = 4096 per K-tile
uint32_t k_kt = sK_smem + kt * 4096;
uint64_t dq = make_umma_desc_kmajor_none(q_kt, 128);
uint64_t dk = make_umma_desc_kmajor_none(k_kt, 128);
// Construct descriptor for this K-tile
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ), 128);
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK), 128);
uint32_t idesc = make_idesc(128, 128);
// MMA
if (lane == 0) {
umma_ss_f16(tb, dq, dk, idesc, kt > 0);
}
@@ -100,22 +92,21 @@ test_umma_qk_hd64_1ktile(const bf16_t* q, const bf16_t* k,
}
__syncthreads();
// Scalar: S[0,j] = sum(Q[0,d]*K[j,d], d=0..hd-1) * scale (full HD)
// Scalar: S[0,j] = sum(Q[0,d]*K[j,d], d=0..63) * scale
if (tid == 0) {
for (int j = 0; j < sk; j++) {
for (int j = 0; j < 128; j++) {
float dot = 0.0f;
for (int d = 0; d < hd; d++)
dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * hd + d]);
for (int d = 0; d < 64; d++)
dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * 64 + d]);
s_scalar[j] = dot * scale;
}
}
__syncthreads();
if (wid == 0) tmem_dealloc(tb, 128);
}
int main() {
printf("=== UMMA QK HD=64, 1 K-tile ===\n");
printf("=== UMMA QK HD=64 (separate SMEM per K-tile) ===\n");
const int HD = 64, SK = 128;
const float SCALE = 1.0f / sqrtf((float)HD);
@@ -134,8 +125,11 @@ int main() {
cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
int smem = (4+16+128*HD*2+128*HD*2+256+127)&~127;
test_umma_qk_hd64_1ktile<<<1, 128, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE, HD, SK);
// SMEM: 4 + 16 + 8 * 128*16*2 + 256
int smem = (4+16 + NKT*2*128*16*2 + 256 + 127) & ~127;
printf("SMEM: %d bytes\n", smem);
test_umma_hd64<<<1, 128, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE);
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
@@ -143,22 +137,16 @@ int main() {
cudaMemcpy(h_s_out, d_s_out, 128*16*sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(h_s_scalar, d_s_scalar, SK*sizeof(float), cudaMemcpyDeviceToHost);
// SMEM verify
printf("Q[0..7] SMEM: "); for(int d=0;d<8;d++) printf("%.4f ",h_s_out[200+d]); printf("\n");
printf("Q[0..7] orig: "); for(int d=0;d<8;d++) printf("%.4f ",bf16_to_f32_host(h_q[d])); printf("\n");
printf("Q[8..15] SMEM: "); for(int d=0;d<8;d++) printf("%.4f ",h_s_out[208+d]); printf("\n");
printf("Q[8..15] orig: "); for(int d=0;d<8;d++) printf("%.4f ",bf16_to_f32_host(h_q[8+d])); printf("\n");
// Compare
printf("S[0,0..7] MMA: "); for(int c=0;c<8;c++) printf("%.6f ",h_s_out[0*16+c]); printf("\n");
printf("S[0,0..7] ref: "); for(int c=0;c<8;c++) printf("%.6f ",h_s_scalar[c]); printf("\n");
float max_diff = 0.0f, max_val = 0.0f;
for (int c = 0; c < 8; c++) {
for (int c = 0; c < 16; c++) {
max_diff = fmaxf(max_diff, fabsf(h_s_out[0*16+c] - h_s_scalar[c]));
max_val = fmaxf(max_val, fabsf(h_s_scalar[c]));
}
float rel_err = max_val > 0 ? max_diff / max_val : max_diff;
printf("Row 0 rel err: %.6f\n", rel_err);
printf("Row 0 rel err (16 cols): %.6f\n", rel_err);
printf("Test %s\n", rel_err < 0.01f ? "PASSED" : "FAILED");
cudaFree(d_q); cudaFree(d_k); cudaFree(d_s_out); cudaFree(d_s_scalar);