test: HD=64 full 4 K-tile accumulate + full-HD scalar reference
This commit is contained in:
@@ -61,19 +61,27 @@ test_umma_qk_hd64_1ktile(const bf16_t* q, const bf16_t* k,
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// MMA with K-tile 0 (columns 0-15)
|
||||
// Descriptor: start = sQ_smem (base of the matrix)
|
||||
// Descriptors
|
||||
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
|
||||
uint32_t sK_smem = __cvta_generic_to_shared(sK);
|
||||
uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 128);
|
||||
uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 128);
|
||||
uint32_t idesc = make_idesc(128, 128);
|
||||
|
||||
if (lane == 0) {
|
||||
umma_ss_f16(tb, desc_q, desc_k, idesc, false);
|
||||
// K-tile loop with accumulate
|
||||
for (int kt = 0; kt < hd / 16; kt++) {
|
||||
// K-tile kt: columns [16*kt, 16*kt+16)
|
||||
// In the canonical layout, this K-tile occupies core-matrix columns core_k = 2*kt and 2*kt+1
|
||||
// Byte offset from the matrix base = 2*kt * 2048 (each core column spans 2048 bytes)
|
||||
uint32_t q_kt = sQ_smem + kt * 4096; // 2 core cols * 2048 bytes = 4096 per K-tile
|
||||
uint32_t k_kt = sK_smem + kt * 4096;
|
||||
uint64_t dq = make_umma_desc_kmajor_none(q_kt, 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none(k_kt, 128);
|
||||
|
||||
if (lane == 0) {
|
||||
umma_ss_f16(tb, dq, dk, idesc, kt > 0);
|
||||
}
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
}
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// Read TMEM
|
||||
for (int n = 0; n < 128 / 8; n++) {
|
||||
@@ -92,11 +100,11 @@ test_umma_qk_hd64_1ktile(const bf16_t* q, const bf16_t* k,
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Scalar: S[0,j] = sum(Q[0,d]*K[j,d], d=0..15) * scale (first K-tile only)
|
||||
// Scalar: S[0,j] = sum(Q[0,d]*K[j,d], d=0..hd-1) * scale (full HD)
|
||||
if (tid == 0) {
|
||||
for (int j = 0; j < sk; j++) {
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < 16; d++) // Only first K-tile
|
||||
for (int d = 0; d < hd; d++)
|
||||
dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * hd + d]);
|
||||
s_scalar[j] = dot * scale;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user