test: separate (128,16) SMEM per K-tile with correct source stride
This commit is contained in:
@@ -1,15 +1,8 @@
|
||||
/**
|
||||
* UMMA QK GEMM Test — HD=64 (4 K-tiles)
|
||||
* UMMA QK GEMM Test — HD=64 (4 K-tiles, separate SMEM per K-tile)
|
||||
*
|
||||
* Full multi-K-tile QK GEMM with proper SMEM writes.
|
||||
* Key fix: source data stride ≠ SMEM tile width — must write manually.
|
||||
*
|
||||
* Pipeline:
|
||||
* 1. Load Q (1, 64) into (128, 64) canonical SMEM
|
||||
* 2. Load K (128, 64) into (128, 64) canonical SMEM
|
||||
* 3. For each K-tile (16 BF16): construct offset descriptor, call MMA with accumulate
|
||||
* 4. Read S from TMEM, apply 1/sqrt(HD) scale
|
||||
* 5. Compare against scalar reference
|
||||
* Each K-tile gets its own (128, 16) SMEM region — no offset descriptors.
|
||||
* Source data stride handled correctly (SRC_HD=64, SMEM_HD=16).
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
@@ -29,43 +22,7 @@ constexpr int HD = 64;
|
||||
constexpr int SK = 128;
|
||||
constexpr int NKT = HD / MMA_K_BF16; // 4
|
||||
constexpr int BLOCK_MN = 128;
|
||||
|
||||
/**
 * Write Q (1, SRC_HD) into a (128, SMEM_HD) canonical-layout SMEM tile.
 *
 * Addressing used by this file: element (row r, col c) is stored at
 *     (c / 8) * CORES_MN * 64  +  (r / 8) * 64  +  (r % 8) * 8  +  (c % 8)
 * i.e. 8x8 "core" blocks of 64 elements, 16 cores stacked along the row
 * dimension, one such stack per group of 8 columns.
 *
 * Only row 0 carries data (q[0 .. SRC_HD)); every other slot is zero.
 *
 * The tile is filled in ONE pass indexed by destination slot: each thread
 * decides, for the slots it owns, whether the value is q[c] or zero. The
 * earlier "zero everything, then scatter row 0" form let one thread zero a
 * slot that another thread writes, with no barrier in between (a shared-memory
 * data race). The single pass needs no barrier inside this function.
 * Columns >= SMEM_HD can never be addressed, so SRC_HD > SMEM_HD no longer
 * writes past the end of the tile; columns >= SRC_HD stay zero.
 *
 * Must be called by every thread of the block (threads stride by blockDim.x).
 * The caller still has to synchronize before the tile is consumed.
 *
 * @tparam SMEM_HD  number of columns in the SMEM tile (multiple of 8)
 * @tparam SRC_HD   number of valid elements in q
 * @param dst       destination tile, 128 * SMEM_HD elements
 * @param q         source row, SRC_HD elements
 */
template<int SMEM_HD, int SRC_HD>
__device__ void write_q_canonical(bf16_t* dst, const bf16_t* q) {
    static_assert(SMEM_HD % 8 == 0, "canonical layout is built from 8-column cores");

    constexpr int CORES_MN   = 128 / 8;              // 16 cores along the row dimension
    constexpr int CORE_ELEMS = 8 * 8;                // 64 elements per 8x8 core
    constexpr int COL_STRIDE = CORES_MN * CORE_ELEMS; // elements per 8-column core stack

    for (int i = threadIdx.x; i < 128 * SMEM_HD; i += blockDim.x) {
        const int ck     = i / COL_STRIDE;   // which 8-column group
        const int in_col = i % COL_STRIDE;   // offset inside that group
        const int lc     = in_col % 8;       // column inside the core
        const int c      = ck * 8 + lc;      // logical column

        // Row 0 is core_mn == 0, local_r == 0, i.e. the first 8 slots of each group.
        const bool is_row0 = in_col < 8;

        if (is_row0 && c < SRC_HD) {
            dst[i] = q[c];
        } else {
            dst[i] = 0;
        }
    }
}
|
||||
|
||||
/**
 * Write K (SK_VAL, SRC_HD), row-major with row stride SRC_HD, into a
 * (128, SMEM_HD) canonical-layout SMEM tile.
 *
 * Addressing used by this file: element (row r, col c) is stored at
 *     (c / 8) * CORES_MN * 64  +  (r / 8) * 64  +  (r % 8) * 8  +  (c % 8)
 *
 * Slots whose row is >= SK_VAL or whose column is >= SRC_HD are zero.
 *
 * The tile is filled in ONE pass indexed by destination slot: the slot index
 * is decoded back to (r, c) and the owning thread stores either k[r, c] or
 * zero. The earlier "zero everything, then scatter" form had different
 * threads zeroing and writing the same slot with no barrier in between (a
 * shared-memory data race). Because only the 128 * SMEM_HD slots of the tile
 * are ever addressed, SK_VAL > 128 or SRC_HD > SMEM_HD cannot write outside
 * the tile; the excess source rows/columns are simply not copied.
 *
 * Must be called by every thread of the block (threads stride by blockDim.x).
 * The caller still has to synchronize before the tile is consumed.
 *
 * @tparam SMEM_HD  number of columns in the SMEM tile (multiple of 8)
 * @tparam SRC_HD   row stride / number of columns of the source matrix
 * @tparam SK_VAL   number of valid source rows
 * @param dst       destination tile, 128 * SMEM_HD elements
 * @param k         source matrix, SK_VAL * SRC_HD elements
 */
template<int SMEM_HD, int SRC_HD, int SK_VAL>
__device__ void write_k_canonical(bf16_t* dst, const bf16_t* k) {
    static_assert(SMEM_HD % 8 == 0, "canonical layout is built from 8-column cores");

    constexpr int CORES_MN   = 128 / 8;              // 16 cores along the row dimension
    constexpr int CORE_ELEMS = 8 * 8;                // 64 elements per 8x8 core
    constexpr int COL_STRIDE = CORES_MN * CORE_ELEMS; // elements per 8-column core stack

    for (int i = threadIdx.x; i < 128 * SMEM_HD; i += blockDim.x) {
        // Decode the destination slot back into logical (row, col).
        const int ck      = i / COL_STRIDE;     // which 8-column group
        const int in_col  = i % COL_STRIDE;
        const int tmn     = in_col / CORE_ELEMS; // which core along the rows
        const int in_core = in_col % CORE_ELEMS;
        const int lr      = in_core / 8;         // row inside the core
        const int lc      = in_core % 8;         // column inside the core
        const int r       = tmn * 8 + lr;
        const int c       = ck * 8 + lc;

        if (r < SK_VAL && c < SRC_HD) {
            dst[i] = k[r * SRC_HD + c];
        } else {
            dst[i] = 0;
        }
    }
}
|
||||
constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16; // 128*16 = 2048 BF16 per K-tile
|
||||
|
||||
__global__ void __launch_bounds__(128)
|
||||
test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
@@ -74,52 +31,70 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / WARP, lane = tid % WARP;
|
||||
|
||||
// SMEM: tmem_base(4) + pad(12) + Q tiles (4 × 2048 BF16) + K tiles (4 × 2048 BF16)
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf;
|
||||
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
|
||||
bf16_t* sK = sQ + 128 * HD; // (128, 64) each = 16384 bytes
|
||||
bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
|
||||
bf16_t* sQ1 = sQ0 + TILE_SZ;
|
||||
bf16_t* sQ2 = sQ1 + TILE_SZ;
|
||||
bf16_t* sQ3 = sQ2 + TILE_SZ;
|
||||
bf16_t* sK0 = sQ3 + TILE_SZ;
|
||||
bf16_t* sK1 = sK0 + TILE_SZ;
|
||||
bf16_t* sK2 = sK1 + TILE_SZ;
|
||||
bf16_t* sK3 = sK2 + TILE_SZ;
|
||||
|
||||
// Load Q (1, 64) → (128, 64) canonical
|
||||
write_q_canonical<HD, HD>(sQ, q);
|
||||
// Load K (128, 64) → (128, 64) canonical
|
||||
write_k_canonical<HD, HD, SK>(sK, k);
|
||||
constexpr int CORES_MN = 16; // 128/8
|
||||
|
||||
// Load Q K-tiles: Q is (1, 64), each K-tile takes 16 dims
|
||||
// Zero all tiles
|
||||
for (int i = tid; i < NKT * TILE_SZ; i += 128) { sQ0[i] = 0; sK0[i] = 0; }
|
||||
__syncthreads();
|
||||
|
||||
// TMEM alloc — 128 columns for (128, 128) Layout D output
|
||||
// Write Q row 0 to each K-tile's SMEM
|
||||
for (int kt = 0; kt < NKT; kt++) {
|
||||
bf16_t* sq = sQ0 + kt * TILE_SZ;
|
||||
for (int d = tid; d < MMA_K_BF16; d += 128) {
|
||||
int ck = d / 8, lc = d % 8;
|
||||
sq[ck * CORES_MN * 64 + lc] = q[kt * MMA_K_BF16 + d];
|
||||
}
|
||||
// Write K for this K-tile: K[r, 16*kt + d] for r=0..127, d=0..15
|
||||
bf16_t* sk = sK0 + kt * TILE_SZ;
|
||||
for (int r = 0; r < SK; r++) {
|
||||
for (int d = tid; d < MMA_K_BF16; d += 128) {
|
||||
int ck = d / 8, lc = d % 8;
|
||||
int tmn = r / 8, lr = r % 8;
|
||||
sk[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = k[r * HD + kt * MMA_K_BF16 + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// TMEM alloc
|
||||
if (wid == 1) {
|
||||
tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
|
||||
}
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// Multi-K-tile QK GEMM
|
||||
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
|
||||
uint32_t sK_smem = __cvta_generic_to_shared(sK);
|
||||
// Multi-K-tile QK GEMM with separate SMEM per K-tile
|
||||
bf16_t* sQ_arr[NKT] = {sQ0, sQ1, sQ2, sQ3};
|
||||
bf16_t* sK_arr[NKT] = {sK0, sK1, sK2, sK3};
|
||||
uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
|
||||
|
||||
for (int kt = 0; kt < 1; kt++) { // DEBUG: single K-tile from full SMEM
|
||||
// K-tile offset in canonical layout:
|
||||
// Each 16-BF16 K-tile spans 2 core columns.
|
||||
// Core column 2*kt starts at offset 2*kt * (128/8 * 128) bytes = 2*kt * 2048 bytes = kt * 4096 bytes.
|
||||
uint32_t q_addr = sQ_smem + kt * BLOCK_MN * 32;
|
||||
uint32_t k_addr = sK_smem + kt * BLOCK_MN * 32;
|
||||
for (int kt = 0; kt < NKT; kt++) {
|
||||
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ_arr[kt]), BLOCK_MN);
|
||||
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK_arr[kt]), BLOCK_MN);
|
||||
|
||||
uint64_t dq = make_umma_desc_kmajor_none(q_addr, BLOCK_MN);
|
||||
uint64_t dk = make_umma_desc_kmajor_none(k_addr, BLOCK_MN);
|
||||
|
||||
// Single thread calls MMA (gau-nernst's elect_one pattern)
|
||||
if (tid == 0) {
|
||||
umma_ss_f16(tb, dq, dk, idesc, kt > 0);
|
||||
}
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Final fence before TMEM read
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// Read S from TMEM (Layout D: 32x32b.x8)
|
||||
// Read S from TMEM
|
||||
for (int n = 0; n < 128 / 8; n++) {
|
||||
const int row = wid * 32;
|
||||
const int col = n * 8;
|
||||
@@ -144,7 +119,7 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
if (tid == 0) {
|
||||
for (int j = 0; j < SK; j++) {
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < 16; d++) // DEBUG: single K-tile
|
||||
for (int d = 0; d < HD; d++)
|
||||
dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j * HD + d]);
|
||||
s_scalar[j] = dot * scale;
|
||||
}
|
||||
@@ -154,7 +129,7 @@ test_umma_hd64(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("=== UMMA QK GEMM HD=64 (4 K-tiles, fixed stride) ===\n");
|
||||
printf("=== UMMA QK GEMM HD=64 (separate SMEM per K-tile) ===\n");
|
||||
const float SCALE = 1.0f / sqrtf((float)HD);
|
||||
|
||||
bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t));
|
||||
@@ -172,7 +147,8 @@ int main() {
|
||||
cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
|
||||
int smem = (4 + 16 + 2 * 128 * HD * sizeof(bf16_t) + 256 + 127) & ~127;
|
||||
// SMEM: 4 + 12(pad) + 8 * 2048*2
|
||||
int smem = (4 + 16 + NKT * 2 * TILE_SZ * sizeof(bf16_t) + 256 + 127) & ~127;
|
||||
printf("SMEM: %d bytes (%d KB)\n", smem, smem / 1024);
|
||||
|
||||
test_umma_hd64<<<1, 128, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE);
|
||||
|
||||
Reference in New Issue
Block a user