test_umma_qk: clean rewrite, hardcoded HD=16, explicit core-matrix layout writes
This commit is contained in:
@@ -4,15 +4,7 @@
|
||||
* Tests that tcgen05.mma produces correct QK attention scores.
|
||||
* Uses K-major, SWIZZLE_NONE, BF16 descriptors with core-matrix SMEM layout.
|
||||
*
|
||||
* Strategy:
|
||||
* 1. Create Q (1, HD) and K (SK, HD) on CPU, copy to GPU
|
||||
* 2. Load Q into SMEM (padded to 128 rows) in core-matrix layout
|
||||
* 3. Load K into SMEM (padded to 128 rows) in core-matrix layout
|
||||
* 4. For each K-tile (16 BF16 columns), call tcgen05.mma SS
|
||||
* 5. Read S from TMEM, compare against CPU reference
|
||||
*
|
||||
* First test: HD=16, SK=128 (single K-tile, single MMA call)
|
||||
* Second test: HD=64, SK=128 (4 K-tiles, accumulate)
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
@@ -31,200 +23,131 @@ using namespace dsv4::kernels::attention;
|
||||
// Host-side FP32 -> BF16 conversion.
//
// Keeps the upper 16 bits of the float's 32-bit pattern and discards the
// lower 16 bits (plain truncation; no rounding is applied).
//
// f: value to convert.
// Returns the 16-bit pattern, converted to bf16_t.
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    // memcpy reinterprets the float's bytes without violating aliasing rules.
    memcpy(&bits, &f, sizeof(bits));
    // After the shift only 16 significant bits remain, so no mask is needed.
    return (uint16_t)(bits >> 16);
}
|
||||
|
||||
// ==================================================================
|
||||
// Test kernel: UMMA QK GEMM for HD=16, SK=128 (single K-tile)
|
||||
// ==================================================================
|
||||
// This is the simplest possible test: one MMA call, no K-tiling.
|
||||
// Q: (1, 16) padded to (128, 16) in SMEM
|
||||
// K: (128, 16) in SMEM (transposed view via K-major descriptor)
|
||||
// S = Q @ K^T: (128, 128) in TMEM → we only care about row 0
|
||||
|
||||
__global__ void __launch_bounds__(NTHREADS)
|
||||
test_umma_qk_hd16(
|
||||
const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
float* __restrict__ s_out, // output: S[0, 0..127] (row 0 of QK)
|
||||
float* __restrict__ s_scalar, // scalar reference: Q @ K^T computed in SMEM
|
||||
int sk, float scale
|
||||
float* __restrict__ s_out, // output: S[0, 0..127] from TMEM
|
||||
float* __restrict__ s_scalar, // scalar reference: Q @ K^T in SMEM
|
||||
float scale
|
||||
) {
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / WARP, lane = tid % WARP;
|
||||
|
||||
constexpr int HD = 16;
|
||||
constexpr int SK = 128;
|
||||
constexpr int KT = MMA_K_TILE; // 16
|
||||
constexpr int N_KTILES = HD / KT; // 1 for HD=16
|
||||
|
||||
// ================================================================
|
||||
// SMEM layout
|
||||
// ================================================================
|
||||
// sQ_ktile: (128, 16) BF16 in K-major core-matrix layout = 4 KB
|
||||
// sK_ktile: (128, 16) BF16 in K-major core-matrix layout = 4 KB
|
||||
// sTmemBase: 4 bytes
|
||||
// sQ_row: (HD) FP32 for scalar reference = 64 bytes
|
||||
// K for the scalar reference is read directly from GMEM (no sK_row buffer is carved out of SMEM)
|
||||
// Total: ~8 KB + overhead, well within 232 KB
|
||||
// sQ_row: (16) floats for scalar reference
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf;
|
||||
|
||||
// Align sQ_ktile to 16 bytes (required for UMMA descriptor)
|
||||
bf16_t* sQ_ktile = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
|
||||
bf16_t* sK_ktile = sQ_ktile + 128 * KT; // After Q (128 * 16 BF16)
|
||||
float* sQ_row = (float*)(sK_ktile + 128 * KT); // After K
|
||||
bf16_t* sK_ktile = sQ_ktile + 128 * 16;
|
||||
float* sQ_row = (float*)(sK_ktile + 128 * 16);
|
||||
|
||||
// Load Q to sQ_row (float for scalar reference)
|
||||
for (int d = tid; d < HD; d += NTHREADS) {
|
||||
for (int d = tid; d < 16; d += NTHREADS) {
|
||||
sQ_row[d] = bf16_to_f32(q[d]);
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// TMEM allocation
|
||||
// TMEM allocation: 128 columns for S (128, 128)
|
||||
// ================================================================
|
||||
// QK result S: (128, 128) FP32 in TMEM = 128 columns
|
||||
// (We only need 128 columns for the full 128×128 result)
|
||||
constexpr int TMEM_COLS = 128;
|
||||
|
||||
if (wid == 0) {
|
||||
uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase);
|
||||
tmem_alloc(smem_ptr, TMEM_COLS);
|
||||
tmem_alloc(smem_ptr, 128);
|
||||
}
|
||||
__syncthreads();
|
||||
uint32_t tmem_base = *sTmemBase;
|
||||
uint32_t tmem_s = tmem_base;
|
||||
|
||||
// Zero TMEM S accumulator
|
||||
if (wid == 0) {
|
||||
for (int col = 0; col < TMEM_COLS; col++) {
|
||||
tmem_store(tmem_s + col, 0, 0, 0, 0);
|
||||
for (int col = 0; col < 128; col++) {
|
||||
tmem_store(tmem_base + col, 0, 0, 0, 0);
|
||||
}
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ================================================================
|
||||
// K-tile loop (just 1 tile for HD=16)
|
||||
// Load Q and K into SMEM in core-matrix layout
|
||||
// ================================================================
|
||||
for (int kt = 0; kt < N_KTILES; kt++) {
|
||||
// Zero SMEM K-tile buffers
|
||||
zero_smem<128, KT>(sQ_ktile);
|
||||
zero_smem<128, KT>(sK_ktile);
|
||||
__syncthreads();
|
||||
|
||||
// Write Q K-tile to SMEM in core-matrix layout
|
||||
// Q is (1, HD), padded to (128, 16). Row 0 has the actual data.
|
||||
// Only write row 0 (the actual query vector).
|
||||
// In core-matrix layout, row 0 is in core (0, 0) and core (0, 1).
|
||||
// Core (0, c) starts at offset c * 64, row 0 within it is at offset 0.
|
||||
// Row 0, column c*8 + j is at core (0, c), position 0*8 + j = j.
|
||||
// So core (0, 0) = sQ_ktile[0..7], core (0, 1) = sQ_ktile[64..71]
|
||||
for (int d = tid; d < KT; d += NTHREADS) {
|
||||
int c = d; // column within this K-tile
|
||||
int core_k = c / 8;
|
||||
int local_c = c % 8;
|
||||
int dst_idx = core_k * 64 + local_c; // tile_mn=0, so (0*CORES_K + core_k)*64 + 0*8 + local_c
|
||||
sQ_ktile[dst_idx] = q[kt * KT + c];
|
||||
}
|
||||
// Write K K-tile to SMEM in core-matrix layout
|
||||
// K is (SK, HD). We write rows 0..SK-1.
|
||||
write_smem_ktile<128>(sK_ktile, k, kt, HD);
|
||||
// ^ k + kt*KT is wrong — k is (SK, HD) row-major.
|
||||
// We need to extract columns [kt*16, kt*16+16) from K.
|
||||
// Let me fix: for each row r of K, the source is k[r * HD + kt * KT + c]
|
||||
// But write_smem_ktile expects (ROWS, hd) row-major with k_tile and hd.
|
||||
// Actually write_smem_ktile(src, k_tile, hd) reads src[r * hd + k_tile*KT + c].
|
||||
// But K in GMEM is k[r * HD + c]. We need k[r * HD + kt*KT + c].
|
||||
// So pass k as the base pointer and HD as the row stride.
|
||||
// write_smem_ktile expects src = (ROWS, hd) with k_tile selecting the column range.
|
||||
// This works: src[r * hd + k_tile * KT + c] = k[r * HD + kt*KT + c]. ✓
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// ============================================================
|
||||
// Construct UMMA descriptors
|
||||
// ============================================================
|
||||
// Both A (Q) and B (K) are in K-major core-matrix layout.
|
||||
// For the K^T operation: the hardware transposes via the descriptor.
|
||||
// For tcgen05.mma, A and B are BOTH K-major. The MMA computes:
|
||||
// D = A × B^T (for TN layout)
|
||||
// where A is (M, K) and B^T is (K, N).
|
||||
// But tcgen05.mma actually computes D = A × B^T when both are K-major?
|
||||
// No — the MMA computes D = A × B^T when A is MN-major and B is K-major,
|
||||
// or D = A^T × B^T when both are K-major.
|
||||
//
|
||||
// Actually, for tcgen05.mma with the instruction descriptor,
|
||||
// we can set transpose flags. Let me re-read the spec.
|
||||
//
|
||||
// The MMA computes: D = f(A) × f(B) + D
|
||||
// where f() can include transpose/negate based on idesc bits.
|
||||
// Without transpose: D = A × B^T (K-major for both = TN layout)
|
||||
//
|
||||
// For FMHA QK: S = Q × K^T
|
||||
// Q is (128, 16) — this is A
|
||||
// K is (128, 16) — we want K^T = (16, 128) for B
|
||||
//
|
||||
// With both K-major (TN layout), the MMA computes:
|
||||
// D = A × B^T = Q × K^T ✓
|
||||
//
|
||||
// Wait, "TN" means A is row-major (T) and B is col-major (N).
|
||||
// K-major is the SMEM layout, not the mathematical layout.
|
||||
// In the MMA's perspective:
|
||||
// A (K-major) = (M, K) with K contiguous → this IS row-major for A
|
||||
// B (K-major) = (N, K) with K contiguous → B^T = (K, N)
|
||||
// D = A × B^T = (M, K) × (K, N) = (M, N) ✓
|
||||
//
|
||||
// So with both descriptors as K-major, the MMA naturally computes Q × K^T.
|
||||
// No transpose flags needed.
|
||||
|
||||
uint32_t sQ_smem = __cvta_generic_to_shared(sQ_ktile);
|
||||
uint32_t sK_smem = __cvta_generic_to_shared(sK_ktile);
|
||||
|
||||
// For a (128, 16) K-tile in core-matrix layout:
|
||||
// LBO = 1, SBO = 2
|
||||
uint64_t desc_q = make_umma_desc_kmajor_none_ktile(sQ_smem);
|
||||
uint64_t desc_k = make_umma_desc_kmajor_none_ktile(sK_smem);
|
||||
|
||||
// ============================================================
|
||||
// Call tcgen05.mma SS (both SMEM → TMEM)
|
||||
// ============================================================
|
||||
// Only ONE thread calls MMA (single-threaded launch model)
|
||||
if (tid == 0) {
|
||||
bool accumulate = (kt > 0); // First tile: zero, rest: accumulate
|
||||
umma_ss_f16(tmem_s, desc_q, desc_k, accumulate);
|
||||
}
|
||||
__syncwarp(); // Ensure MMA is issued
|
||||
// Wait for MMA to complete
|
||||
if (wid == 0 && lane == 0) {
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
// Zero both buffers first
|
||||
for (int i = tid; i < 128 * 16; i += NTHREADS) {
|
||||
sQ_ktile[i] = 0;
|
||||
sK_ktile[i] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Write Q (1, 16) padded to (128, 16) in core-matrix layout
|
||||
// Core-matrix: each 8x8 BF16 tile at (tile_mn, tile_k) offset (tile_mn * 2 + tile_k) * 64
|
||||
// Q row 0, cols 0-15: tiles (0,0) and (0,1)
|
||||
// Tile (0,0): positions 0-7 = Q[0..7], Tile (0,1): positions 0-7 = Q[8..15]
|
||||
for (int d = tid; d < 16; d += NTHREADS) {
|
||||
int tile_k = d / 8;
|
||||
int local_c = d % 8;
|
||||
int dst_idx = (0 * 2 + tile_k) * 64 + 0 * 8 + local_c; // row 0 (local_r=0)
|
||||
sQ_ktile[dst_idx] = q[d];
|
||||
}
|
||||
|
||||
// Write K (128, 16) in core-matrix layout
|
||||
// K is (128, 16) row-major in GMEM. In core-matrix:
|
||||
// tile (r/8, c/8) at offset ((r/8) * 2 + (c/8)) * 64 + (r%8)*8 + (c%8)
|
||||
for (int i = tid; i < 128 * 16; i += NTHREADS) {
|
||||
int r = i / 16;
|
||||
int c = i % 16;
|
||||
int tile_mn = r / 8;
|
||||
int tile_k = c / 8;
|
||||
int local_r = r % 8;
|
||||
int local_c = c % 8;
|
||||
int dst_idx = (tile_mn * 2 + tile_k) * 64 + local_r * 8 + local_c;
|
||||
sK_ktile[dst_idx] = k[i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ================================================================
|
||||
// Construct UMMA descriptors and call MMA
|
||||
// ================================================================
|
||||
// Both A (Q) and B (K) are (128, 16) in K-major core-matrix layout.
|
||||
// For K-major NONE with BLOCK_K=16: LBO=1, SBO=2
|
||||
uint32_t sQ_smem = __cvta_generic_to_shared(sQ_ktile);
|
||||
uint32_t sK_smem = __cvta_generic_to_shared(sK_ktile);
|
||||
|
||||
uint64_t desc_q = make_umma_desc_kmajor_none_ktile(sQ_smem);
|
||||
uint64_t desc_k = make_umma_desc_kmajor_none_ktile(sK_smem);
|
||||
|
||||
// Single-threaded MMA launch
|
||||
if (tid == 0) {
|
||||
umma_ss_f16(tmem_base, desc_q, desc_k, /*accumulate=*/false);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Wait for MMA to complete
|
||||
if (wid == 0 && lane == 0) {
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ================================================================
|
||||
// Read S from TMEM and write to output
|
||||
// ================================================================
|
||||
// S is (128, 128) FP32 in TMEM. Row 0 is the actual attention score.
|
||||
// TMEM lane mapping: lane i reads positions i*4+0..3 per column.
|
||||
// For row 0, lane 0 reads positions 0-3 from each column.
|
||||
// S is (128, 128) FP32 in TMEM. We care about row 0 (the query row).
|
||||
// TMEM column col stores 128 FP32. Lane 0's u0 = S[0, col].
|
||||
if (wid == 0) {
|
||||
for (int col = 0; col < 128; col++) {
|
||||
uint32_t u0, u1, u2, u3;
|
||||
tmem_load(tmem_s + col, u0, u1, u2, u3);
|
||||
// Lane 0's u0 = row 0, position 0 in this column
|
||||
// But which row does this correspond to in the S matrix?
|
||||
// S is (128, 128) — row m, column n.
|
||||
// TMEM column col stores 128 FP32 values (rows 0-127 of column col).
|
||||
// Lane 0 reads positions 0-3: S[0,col], S[1,col], S[2,col], S[3,col]
|
||||
tmem_load(tmem_base + col, u0, u1, u2, u3);
|
||||
if (lane == 0) {
|
||||
float val = u32_to_f32(u0); // S[0, col]
|
||||
if (col < sk) {
|
||||
s_out[col] = val;
|
||||
}
|
||||
s_out[col] = u32_to_f32(u0); // S[0, col] (un-normalized)
|
||||
}
|
||||
}
|
||||
tmem_fence_load();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@@ -232,10 +155,10 @@ test_umma_qk_hd16(
|
||||
// Scalar reference: compute Q @ K^T in SMEM (row 0 only)
|
||||
// ================================================================
|
||||
if (tid == 0) {
|
||||
for (int c = 0; c < sk; c++) {
|
||||
for (int c = 0; c < 128; c++) {
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < HD; d++) {
|
||||
dot += sQ_row[d] * bf16_to_f32(k[c * HD + d]);
|
||||
for (int d = 0; d < 16; d++) {
|
||||
dot += sQ_row[d] * bf16_to_f32(k[c * 16 + d]);
|
||||
}
|
||||
s_scalar[c] = dot * scale;
|
||||
}
|
||||
@@ -244,7 +167,7 @@ test_umma_qk_hd16(
|
||||
|
||||
// TMEM dealloc
|
||||
if (wid == 0) {
|
||||
tmem_dealloc(tmem_base, TMEM_COLS);
|
||||
tmem_dealloc(tmem_base, 128);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,9 +178,9 @@ test_umma_qk_hd16(
|
||||
int main() {
|
||||
printf("=== UMMA QK GEMM Test (HD=16, SK=128) ===\n");
|
||||
|
||||
constexpr int HD = 16;
|
||||
constexpr int SK = 128;
|
||||
constexpr float SCALE = 1.0f / sqrtf((float)HD);
|
||||
const int HD = 16;
|
||||
const int SK = 128;
|
||||
const float SCALE = 1.0f / sqrtf((float)HD);
|
||||
|
||||
// Allocate host memory
|
||||
bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t));
|
||||
@@ -288,22 +211,16 @@ int main() {
|
||||
cudaMemset(d_s_scalar, 0, SK * sizeof(float));
|
||||
|
||||
// Compute SMEM size
|
||||
// sTmemBase: 4 bytes
|
||||
// sQ_ktile: 128 * 16 * 2 = 4096 bytes (aligned to 16B)
|
||||
// sK_ktile: 128 * 16 * 2 = 4096 bytes
|
||||
// sQ_row: 16 * 4 = 64 bytes
|
||||
// Total: 4 + 16 (alignment) + 4096 + 4096 + 64 = ~8276 bytes
|
||||
int smem_size = 4 + 16 + 128 * MMA_K_TILE * 2 + 128 * MMA_K_TILE * 2 + HD * 4 + 256; // extra padding
|
||||
smem_size = (smem_size + 127) & ~127; // align to 128B
|
||||
|
||||
// sTmemBase: 4 + alignment 16 + sQ: 128*16*2 + sK: 128*16*2 + sQ_row: 16*4 + padding
|
||||
int smem_size = 4 + 16 + 128*16*2 + 128*16*2 + 16*4 + 256;
|
||||
smem_size = (smem_size + 127) & ~127;
|
||||
printf("SMEM size: %d bytes\n", smem_size);
|
||||
printf("Launching kernel with %d threads, 1 CTA\n", NTHREADS);
|
||||
|
||||
// Launch
|
||||
dim3 grid(1, 1, 1);
|
||||
dim3 block(NTHREADS);
|
||||
test_umma_qk_hd16<<<grid, block, smem_size>>>(
|
||||
d_q, d_k, d_s_out, d_s_scalar, SK, SCALE);
|
||||
d_q, d_k, d_s_out, d_s_scalar, SCALE);
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
@@ -322,8 +239,7 @@ int main() {
|
||||
for (int c = 0; c < 8; c++) printf("%.4f ", h_s_scalar[c]);
|
||||
printf("\n");
|
||||
|
||||
float max_diff = 0.0f;
|
||||
float max_val = 0.0f;
|
||||
float max_diff = 0.0f, max_val = 0.0f;
|
||||
for (int c = 0; c < SK; c++) {
|
||||
float diff = fabsf(h_s_out[c] - h_s_scalar[c]);
|
||||
max_diff = fmaxf(max_diff, diff);
|
||||
|
||||
Reference in New Issue
Block a user