From 5a65d46c26e5df6fa0f9d0d331d3352fe2c217ab Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 12:18:06 +0000 Subject: [PATCH] =?UTF-8?q?test:=20HD=3D64=20with=20separate=20SMEM=20per?= =?UTF-8?q?=20K-tile=20=E2=80=94=20no=20offset=20descriptors=20needed?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_umma_qk_hd64.cu | 112 ++++++++++++++------------------ 1 file changed, 50 insertions(+), 62 deletions(-) diff --git a/tests/unit/test_umma_qk_hd64.cu b/tests/unit/test_umma_qk_hd64.cu index 7bba6c54..b2f74ad7 100644 --- a/tests/unit/test_umma_qk_hd64.cu +++ b/tests/unit/test_umma_qk_hd64.cu @@ -1,7 +1,7 @@ /** - * UMMA QK GEMM Test — HD=64, SK=128, 1 K-tile (columns 0-15 only) - * Verifies the K-tile descriptor offset is correct by comparing - * against the HD=16 result (which is proven correct). + * UMMA QK GEMM Test — HD=64 (4 K-tiles), separate SMEM per K-tile. + * Each K-tile has its own (128, 16) SMEM region. + * This avoids the offset descriptor issue. */ #include @@ -18,64 +18,56 @@ using namespace dsv4::kernels::attention; static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } +constexpr int NKT = 4; // hd=64 / 16 + __global__ void __launch_bounds__(128) -test_umma_qk_hd64_1ktile(const bf16_t* q, const bf16_t* k, - float* s_out, float* s_scalar, float scale, int hd, int sk) +test_umma_hd64(const bf16_t* q, const bf16_t* k, + float* s_out, float* s_scalar, float scale) { const int tid = threadIdx.x; const int wid = tid / 32, lane = tid % 32; + // Separate SMEM per K-tile: sQ[4][128*16] and sK[4][128*16] + // Each (128, 16) = 4096 bytes in canonical layout extern __shared__ char sbuf[]; uint32_t* sTmemBase = (uint32_t*)sbuf; - bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); - bf16_t* sK = sQ + 128 * hd; + // 4 Q K-tiles + 4 K K-tiles, each (128, 16) BF16 = 4096 bytes + bf16_t* sQ_base = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); + bf16_t* sK_base = sQ_base + NKT * 128 * 16; // TMEM alloc - if (wid == 0) { - tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128); - } + if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128); __syncthreads(); uint32_t tb = *sTmemBase; - // Zero SMEM - for (int i = tid; i < 128 * hd; i += 128) { sQ[i] = 0; sK[i] = 0; } - __syncthreads(); + // Load each K-tile separately + for (int kt = 0; kt < NKT; kt++) { + bf16_t* sQ = sQ_base + kt * 128 * 16; + bf16_t* sK = sK_base + kt * 128 * 16; - // Write Q (1, hd) to sQ row 0 in canonical layout - for (int d = tid; d < hd; d += 128) { - int ck = d / 8, lc = d % 8; - sQ[ck * 16 * 64 + lc] = q[d]; - } - // Write K (sk, hd) to sK in canonical layout - for (int i = tid; i < sk * hd; i += 128) { - int r = i / hd, c = i % hd; - int tmn = r / 8, ck = c / 8, lr = r % 8, lc = c % 8; - sK[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k[i]; - } - __syncthreads(); + // Zero this K-tile + for (int i = tid; i < 128 * 16; i += 128) { sQ[i] = 0; sK[i] = 0; } + __syncthreads(); - // Verify Q[0..7] in SMEM - if (tid == 0) { - for (int d = 0; d < 8; d++) s_out[200+d] = bf16_to_f32(sQ[d]); - for (int d = 0; d < 8; d++) s_out[208+d] = bf16_to_f32(sQ[1024+d]); // Q[8..15] - } - __syncthreads(); + // Write Q's K-tile: row 0, columns [16*kt, 16*kt+16) + for (int d = tid; d < 16; d += 128) { + int ck = d / 8, lc = d % 8; + sQ[ck * 16 * 64 + lc] = q[kt * 16 + d]; + } + // Write K's K-tile: all rows, columns [16*kt, 16*kt+16) + for (int i = tid; i < 128 * 16; i += 128) { + int r = i / 16, c = i % 16; + int tmn = r / 8, ck = c / 8, lr = r % 8, lc = c % 8; + sK[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k[r * 64 + kt * 16 + c]; + } + __syncthreads(); - // Descriptors - uint32_t sQ_smem = __cvta_generic_to_shared(sQ); - uint32_t sK_smem = __cvta_generic_to_shared(sK); - uint32_t idesc = make_idesc(128, 128); - - // K-tile loop with accumulate - for (int kt = 0; kt < hd / 16; kt++) { // Full K-tile loop - // K-tile kt: columns [16*kt, 16*kt+16) - // In canonical layout, columns start at core_k = 2*kt and 2*kt+1 - // Offset = 2*kt * 2048 bytes from matrix base - uint32_t q_kt = sQ_smem + kt * 4096; // 2 core cols * 2048 bytes = 4096 per K-tile - uint32_t k_kt = sK_smem + kt * 4096; - uint64_t dq = make_umma_desc_kmajor_none(q_kt, 128); - uint64_t dk = make_umma_desc_kmajor_none(k_kt, 128); + // Construct descriptor for this K-tile + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ), 128); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK), 128); + uint32_t idesc = make_idesc(128, 128); + // MMA if (lane == 0) { umma_ss_f16(tb, dq, dk, idesc, kt > 0); } @@ -100,22 +92,21 @@ test_umma_qk_hd64_1ktile(const bf16_t* q, const bf16_t* k, } __syncthreads(); - // Scalar: S[0,j] = sum(Q[0,d]*K[j,d], d=0..hd-1) * scale (full HD) + // Scalar: S[0,j] = sum(Q[0,d]*K[j,d], d=0..63) * scale if (tid == 0) { - for (int j = 0; j < sk; j++) { + for (int j = 0; j < 128; j++) { float dot = 0.0f; - for (int d = 0; d < hd; d++) - dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * hd + d]); + for (int d = 0; d < 64; d++) + dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * 64 + d]); s_scalar[j] = dot * scale; } } - __syncthreads(); if (wid == 0) tmem_dealloc(tb, 128); } int main() { - printf("=== UMMA QK HD=64, 1 K-tile ===\n"); + printf("=== UMMA QK HD=64 (separate SMEM per K-tile) ===\n"); const int HD = 64, SK = 128; const float SCALE = 1.0f / sqrtf((float)HD); @@ -134,8 +125,11 @@ int main() { cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice); cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice); - int smem = (4+16+128*HD*2+128*HD*2+256+127)&~127; - test_umma_qk_hd64_1ktile<<<1, 128, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE, HD, SK); + // SMEM: 4 + 16 + 8 * 128*16*2 + 256 + int smem = (4+16 + NKT*2*128*16*2 + 256 + 127) & ~127; + printf("SMEM: %d bytes\n", smem); + + test_umma_hd64<<<1, 128, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE); cudaError_t err = cudaDeviceSynchronize(); if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } @@ -143,22 +137,16 @@ int main() { cudaMemcpy(h_s_out, d_s_out, 128*16*sizeof(float), cudaMemcpyDeviceToHost); cudaMemcpy(h_s_scalar, d_s_scalar, SK*sizeof(float), cudaMemcpyDeviceToHost); - // SMEM verify - printf("Q[0..7] SMEM: "); for(int d=0;d<8;d++) printf("%.4f ",h_s_out[200+d]); printf("\n"); - printf("Q[0..7] orig: "); for(int d=0;d<8;d++) printf("%.4f ",bf16_to_f32_host(h_q[d])); printf("\n"); - printf("Q[8..15] SMEM: "); for(int d=0;d<8;d++) printf("%.4f ",h_s_out[208+d]); printf("\n"); - printf("Q[8..15] orig: "); for(int d=0;d<8;d++) printf("%.4f ",bf16_to_f32_host(h_q[8+d])); printf("\n"); - - // Compare printf("S[0,0..7] MMA: "); for(int c=0;c<8;c++) printf("%.6f ",h_s_out[0*16+c]); printf("\n"); printf("S[0,0..7] ref: "); for(int c=0;c<8;c++) printf("%.6f ",h_s_scalar[c]); printf("\n"); + float max_diff = 0.0f, max_val = 0.0f; - for (int c = 0; c < 8; c++) { + for (int c = 0; c < 16; c++) { max_diff = fmaxf(max_diff, fabsf(h_s_out[0*16+c] - h_s_scalar[c])); max_val = fmaxf(max_val, fabsf(h_s_scalar[c])); } float rel_err = max_val > 0 ? max_diff / max_val : max_diff; - printf("Row 0 rel err: %.6f\n", rel_err); + printf("Row 0 rel err (16 cols): %.6f\n", rel_err); printf("Test %s\n", rel_err < 0.01f ? "PASSED" : "FAILED"); cudaFree(d_q); cudaFree(d_k); cudaFree(d_s_out); cudaFree(d_s_scalar);