From 8b2200a6d3f36979beadc8195412199e07930d4d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 12:07:50 +0000 Subject: [PATCH] test: HD=64 full 4 K-tile accumulate + full-HD scalar reference --- tests/unit/test_umma_qk_hd64.cu | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/tests/unit/test_umma_qk_hd64.cu b/tests/unit/test_umma_qk_hd64.cu index ee6a3db2..6bd49721 100644 --- a/tests/unit/test_umma_qk_hd64.cu +++ b/tests/unit/test_umma_qk_hd64.cu @@ -61,19 +61,27 @@ test_umma_qk_hd64_1ktile(const bf16_t* q, const bf16_t* k, } __syncthreads(); - // MMA with K-tile 0 (columns 0-15) - // Descriptor: start = sQ_smem (base of the matrix) + // Descriptors uint32_t sQ_smem = __cvta_generic_to_shared(sQ); uint32_t sK_smem = __cvta_generic_to_shared(sK); - uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 128); - uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 128); uint32_t idesc = make_idesc(128, 128); - if (lane == 0) { - umma_ss_f16(tb, desc_q, desc_k, idesc, false); + // K-tile loop with accumulate + for (int kt = 0; kt < hd / 16; kt++) { + // K-tile kt: columns [16*kt, 16*kt+16) + // In canonical layout, columns start at core_k = 2*kt and 2*kt+1 + // Offset = 2*kt * 2048 bytes from matrix base + uint32_t q_kt = sQ_smem + kt * 4096; // 2 core cols * 2048 bytes = 4096 per K-tile + uint32_t k_kt = sK_smem + kt * 4096; + uint64_t dq = make_umma_desc_kmajor_none(q_kt, 128); + uint64_t dk = make_umma_desc_kmajor_none(k_kt, 128); + + if (lane == 0) { + umma_ss_f16(tb, dq, dk, idesc, kt > 0); + } + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); } - asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); - __syncthreads(); // Read TMEM for (int n = 0; n < 128 / 8; n++) { @@ -92,11 +100,11 @@ test_umma_qk_hd64_1ktile(const bf16_t* q, const bf16_t* k, } __syncthreads(); - // Scalar: S[0,j] = sum(Q[0,d]*K[j,d], d=0..15) * scale (first K-tile only) + // Scalar: S[0,j] = sum(Q[0,d]*K[j,d], d=0..hd-1) * scale (full HD) if (tid == 0) { for (int j = 0; j < sk; j++) { float dot = 0.0f; - for (int d = 0; d < 16; d++) // Only first K-tile + for (int d = 0; d < hd; d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * hd + d]); s_scalar[j] = dot * scale; }