test: separate (128,16) SMEM per K-tile with correct source stride

This commit is contained in:
2026-05-28 12:57:38 +00:00
parent f244c4fdd2
commit c936940428

View File

@@ -1,15 +1,8 @@
/**
* UMMA QK GEMM Test — HD=64 (4 K-tiles)
* UMMA QK GEMM Test — HD=64 (4 K-tiles, separate SMEM per K-tile)
*
* Full multi-K-tile QK GEMM with proper SMEM writes.
* Key fix: source data stride ≠ SMEM tile width — must write manually.
*
* Pipeline:
* 1. Load Q (1, 64) into (128, 64) canonical SMEM
* 2. Load K (128, 64) into (128, 64) canonical SMEM
* 3. For each K-tile (16 BF16): construct offset descriptor, call MMA with accumulate
* 4. Read S from TMEM, apply 1/sqrt(HD) scale
* 5. Compare against scalar reference
* Each K-tile gets its own (128, 16) SMEM region — no offset descriptors.
* Source data stride handled correctly (SRC_HD=64, SMEM_HD=16).
*/
#include <cuda_runtime.h>
@@ -29,43 +22,7 @@ constexpr int HD = 64;
constexpr int SK = 128;
constexpr int NKT = HD / MMA_K_BF16; // 4
constexpr int BLOCK_MN = 128;
/**
* Write Q (1, SRC_HD) into (128, SMEM_HD) canonical layout.
* Only row 0 has data. Source stride = SRC_HD, SMEM cols = SMEM_HD.
*/
template<int SMEM_HD, int SRC_HD>
__device__ void write_q_canonical(bf16_t* dst, const bf16_t* q) {
constexpr int CORES_MN = 128 / 8; // 16
constexpr int CORES_K = SMEM_HD / 8;
// Zero all
for (int i = threadIdx.x; i < 128 * SMEM_HD; i += 128) dst[i] = 0;
// Row 0 only: core_mn=0, local_r=0
for (int c = threadIdx.x; c < SRC_HD; c += 128) {
int ck = c / 8, lc = c % 8;
dst[ck * CORES_MN * 64 + lc] = q[c];
}
}
/**
* Write K (SK, SRC_HD) into (128, SMEM_HD) canonical layout.
* Source stride = SRC_HD, SMEM cols = SMEM_HD.
*/
template<int SMEM_HD, int SRC_HD, int SK_VAL>
__device__ void write_k_canonical(bf16_t* dst, const bf16_t* k) {
constexpr int CORES_MN = 128 / 8; // 16
// Zero all
for (int i = threadIdx.x; i < 128 * SMEM_HD; i += 128) dst[i] = 0;
// Write actual rows
for (int i = threadIdx.x; i < SK_VAL * SMEM_HD; i += 128) {
int r = i / SMEM_HD;
int c = i % SMEM_HD;
if (r >= SK_VAL || c >= SRC_HD) continue;
int ck = c / 8, lc = c % 8;
int tmn = r / 8, lr = r % 8;
dst[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k[r * SRC_HD + c];
}
}
constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16; // 128*16 = 2048 BF16 per K-tile
__global__ void __launch_bounds__(128)
test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
@@ -74,52 +31,70 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
const int tid = threadIdx.x;
const int wid = tid / WARP, lane = tid % WARP;
// SMEM: tmem_base(4) + pad(12) + Q tiles (4 × 2048 BF16) + K tiles (4 × 2048 BF16)
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
bf16_t* sK = sQ + 128 * HD; // (128, 64) each = 16384 bytes
bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
bf16_t* sQ1 = sQ0 + TILE_SZ;
bf16_t* sQ2 = sQ1 + TILE_SZ;
bf16_t* sQ3 = sQ2 + TILE_SZ;
bf16_t* sK0 = sQ3 + TILE_SZ;
bf16_t* sK1 = sK0 + TILE_SZ;
bf16_t* sK2 = sK1 + TILE_SZ;
bf16_t* sK3 = sK2 + TILE_SZ;
// Load Q (1, 64) → (128, 64) canonical
write_q_canonical<HD, HD>(sQ, q);
// Load K (128, 64) → (128, 64) canonical
write_k_canonical<HD, HD, SK>(sK, k);
constexpr int CORES_MN = 16; // 128/8
// Load Q K-tiles: Q is (1, 64), each K-tile takes 16 dims
// Zero all tiles
for (int i = tid; i < NKT * TILE_SZ; i += 128) { sQ0[i] = 0; sK0[i] = 0; }
__syncthreads();
// TMEM alloc — 128 columns for (128, 128) Layout D output
// Write Q row 0 to each K-tile's SMEM
for (int kt = 0; kt < NKT; kt++) {
bf16_t* sq = sQ0 + kt * TILE_SZ;
for (int d = tid; d < MMA_K_BF16; d += 128) {
int ck = d / 8, lc = d % 8;
sq[ck * CORES_MN * 64 + lc] = q[kt * MMA_K_BF16 + d];
}
// Write K for this K-tile: K[r, 16*kt + d] for r=0..127, d=0..15
bf16_t* sk = sK0 + kt * TILE_SZ;
for (int r = 0; r < SK; r++) {
for (int d = tid; d < MMA_K_BF16; d += 128) {
int ck = d / 8, lc = d % 8;
int tmn = r / 8, lr = r % 8;
sk[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k[r * HD + kt * MMA_K_BF16 + d];
}
}
}
__syncthreads();
// TMEM alloc
if (wid == 1) {
tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
}
__syncthreads();
uint32_t tb = *sTmemBase;
// Multi-K-tile QK GEMM
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
uint32_t sK_smem = __cvta_generic_to_shared(sK);
// Multi-K-tile QK GEMM with separate SMEM per K-tile
bf16_t* sQ_arr[NKT] = {sQ0, sQ1, sQ2, sQ3};
bf16_t* sK_arr[NKT] = {sK0, sK1, sK2, sK3};
uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
for (int kt = 0; kt < 1; kt++) { // DEBUG: single K-tile from full SMEM
// K-tile offset in canonical layout:
// Each 16-BF16 K-tile spans 2 core columns.
// Core column 2*kt starts at offset 2*kt * (128/8 * 128) bytes = 2*kt * 2048 bytes = kt * 4096 bytes.
uint32_t q_addr = sQ_smem + kt * BLOCK_MN * 32;
uint32_t k_addr = sK_smem + kt * BLOCK_MN * 32;
for (int kt = 0; kt < NKT; kt++) {
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ_arr[kt]), BLOCK_MN);
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK_arr[kt]), BLOCK_MN);
uint64_t dq = make_umma_desc_kmajor_none(q_addr, BLOCK_MN);
uint64_t dk = make_umma_desc_kmajor_none(k_addr, BLOCK_MN);
// Single thread calls MMA (gau-nernst's elect_one pattern)
if (tid == 0) {
umma_ss_f16(tb, dq, dk, idesc, kt > 0);
}
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
}
// Final fence before TMEM read
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
// Read S from TMEM (Layout D: 32x32b.x8)
// Read S from TMEM
for (int n = 0; n < 128 / 8; n++) {
const int row = wid * 32;
const int col = n * 8;
@@ -144,7 +119,7 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
if (tid == 0) {
for (int j = 0; j < SK; j++) {
float dot = 0.0f;
for (int d = 0; d < 16; d++) // DEBUG: single K-tile
for (int d = 0; d < HD; d++)
dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * HD + d]);
s_scalar[j] = dot * scale;
}
@@ -154,7 +129,7 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
}
int main() {
printf("=== UMMA QK GEMM HD=64 (4 K-tiles, fixed stride) ===\n");
printf("=== UMMA QK GEMM HD=64 (separate SMEM per K-tile) ===\n");
const float SCALE = 1.0f / sqrtf((float)HD);
bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t));
@@ -172,7 +147,8 @@ int main() {
cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
int smem = (4 + 16 + 2 * 128 * HD * sizeof(bf16_t) + 256 + 127) & ~127;
// SMEM: 4 + 12(pad) + 8 * 2048*2
int smem = (4 + 16 + NKT * 2 * TILE_SZ * sizeof(bf16_t) + 256 + 127) & ~127;
printf("SMEM: %d bytes (%d KB)\n", smem, smem / 1024);
test_umma_hd64<<<1, 128, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE);