test: gau-nernst pattern — fence::after_thread_sync, 4 warps, 128 threads, 32x32b.x8 loop

This commit is contained in:
2026-05-28 11:28:47 +00:00
parent a048b56886
commit c01d6fddf4

View File

@@ -1,7 +1,10 @@
/**
* UMMA QK GEMM Test (HD=16, SK=128)
* Uses Layout D TMEM read (32x32b.x8, 4 warps) to correctly read MMA output.
* Debugging the 4× scaling factor observed with 16x256b read.
* Using gau-nernst's exact MMA + epilogue pattern:
* - tcgen05.fence::after_thread_sync before TMEM read
* - 32x32b.x8 TMEM read with row = warp_id * 32
* - tcgen05.wait::ld.sync.aligned after each read
* - Output: off_m + tid per thread (each thread = one row)
*/
#include <cuda_runtime.h>
@@ -18,61 +21,56 @@ using namespace dsv4::kernels::attention;
static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; }
__global__ void __launch_bounds__(NTHREADS)
__global__ void __launch_bounds__(128) // 4 warps minimum for Layout D
test_umma_qk_hd16(const bf16_t* q, const bf16_t* k,
float* s_out, float* s_scalar, float scale)
{
const int tid = threadIdx.x;
const int wid = tid / WARP, lane = tid % WARP;
const int wid = tid / 32, lane = tid % 32;
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
bf16_t* sK = sQ + 128 * 16 + 4096;
bf16_t* sK = sQ + 128 * 16 + 4096; // 8KB padding after Q
float* sQ_row = (float*)(sK + 128 * 16);
for (int d = tid; d < 16; d += NTHREADS) sQ_row[d] = bf16_to_f32(q[d]);
for (int d = tid; d < 16; d += 128) sQ_row[d] = bf16_to_f32(q[d]);
// TMEM alloc (128 cols)
if (wid == 0) {
if (wid == 1) {
tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
}
__syncthreads();
uint32_t tb = *sTmemBase;
// Load Q and K
// Load Q and K into SMEM in canonical layout
write_q_to_smem<16>(sQ, q);
write_k_to_smem<128, 16>(sK, k);
// Zero the padding between Q and K
bf16_t* sQ_pad = sQ + 128 * 16; // After Q data
bf16_t* sK_end = sK + 128 * 16; // After K data
for (int i = tid; i < 4096; i += NTHREADS) sQ_pad[i] = 0; // 8KB padding
// Zero padding
bf16_t* sQ_pad = sQ + 128 * 16;
for (int i = tid; i < 4096; i += 128) sQ_pad[i] = 0;
__syncthreads();
// Descriptors
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
uint32_t sK_smem = __cvta_generic_to_shared(sK);
// Try LBO = 32 (128/4 * 16 / 16 = 32 in 16B units)
// Hypothesis: M=128 has 4 sub-tiles, each with 32 rows
// So LBO should be 32 * 16 = 512 bytes (32 in 16B units)
uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 32); // LBO = 32 * 16 = 512B
uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 32);
uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 128);
uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 128);
uint32_t idesc = make_idesc(128, 128);
// MMA — 1 thread calls (4× scaling is expected for M=128 cta_group::1)
// MMA — 1 thread issues (following gau-nernst pattern)
if (tid == 0) {
umma_ss_f16(tb, desc_q, desc_k, idesc, false);
}
__syncwarp();
if (wid == 0 && lane == 0) tmem_fence_store();
// tcgen05.fence::after_thread_sync (CRITICAL — correct MMA→TMEM load fence)
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
// Read from TMEM using Layout D: 32x32b.x8 format
// Read from TMEM using Layout D: 32x32b.x8 (gau-nernst pattern)
// Each warp reads 32 rows × 8 columns
// Warp 0: rows 0-31, warp 1: rows 32-63, warp 2: rows 64-95, warp 3: rows 96-127
if (wid < 4) {
for (int n = 0; n < 128 / 8; n++) {
const int row = wid * 32;
const int col = 0;
const int col = n * 8;
const int addr = tb + (row << 16) + col;
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];"
@@ -80,10 +78,15 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k,
"=f"(tmp[4]), "=f"(tmp[5]), "=f"(tmp[6]), "=f"(tmp[7])
: "r"(addr));
asm volatile("tcgen05.wait::ld.sync.aligned;");
// Lane 0 gets S[wid*32 + 0, 0..7] in tmp[0..7]
if (lane == 0) {
// Lane 0 writes first 8 output values for this warp's row
// Each lane gets values for its row (tid = wid*32 + lane)
int out_row = wid * 32 + lane;
if (lane == 0 && n < 1) { // Only first 8 cols for debug
for (int c = 0; c < 8; c++) {
s_out[wid * 8 + c] = tmp[c] * 0.25f; // Divide by 4 (Layout D scaling)
if (out_row < 128) {
s_out[out_row * 8 + c] = tmp[c];
}
}
}
}
@@ -100,17 +103,18 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k,
}
__syncthreads();
// TMEM dealloc
if (wid == 0) tmem_dealloc(tb, 128);
}
int main() {
printf("=== UMMA QK GEMM Test (Layout D read) ===\n");
printf("=== UMMA QK GEMM (gau-nernst pattern, 4 warps) ===\n");
const int HD = 16, SK = 128;
const float SCALE = 1.0f / sqrtf((float)HD);
bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t));
bf16_t* h_k = (bf16_t*)malloc(SK * HD * sizeof(bf16_t));
float* h_s_out = (float*)calloc(256, sizeof(float));
float* h_s_out = (float*)calloc(128*8, sizeof(float));
float* h_s_scalar = (float*)calloc(SK, sizeof(float));
srand(42);
@@ -119,33 +123,42 @@ int main() {
bf16_t *d_q, *d_k; float *d_s_out, *d_s_scalar;
cudaMalloc(&d_q, HD*sizeof(bf16_t)); cudaMalloc(&d_k, SK*HD*sizeof(bf16_t));
cudaMalloc(&d_s_out, 256*sizeof(float)); cudaMalloc(&d_s_scalar, SK*sizeof(float));
cudaMalloc(&d_s_out, 128*8*sizeof(float)); cudaMalloc(&d_s_scalar, SK*sizeof(float));
cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
cudaMemset(d_s_out, 0, 256*sizeof(float)); cudaMemset(d_s_scalar, 0, SK*sizeof(float));
cudaMemset(d_s_out, 0, 128*8*sizeof(float)); cudaMemset(d_s_scalar, 0, SK*sizeof(float));
int smem = (4 + 16 + 128*16*2 + 4096*2 + 128*16*2 + 16*4 + 256 + 127) & ~127;
test_umma_qk_hd16<<<1, NTHREADS, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE);
test_umma_qk_hd16<<<1, 128, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE);
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
cudaMemcpy(h_s_out, d_s_out, 256*sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(h_s_out, d_s_out, 128*8*sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(h_s_scalar, d_s_scalar, SK*sizeof(float), cudaMemcpyDeviceToHost);
// Print S[0,0] from each warp's perspective
printf("S from TMEM (Layout D, 32x32b.x8):\n");
for (int w = 0; w < 4; w++) {
printf(" Warp %d (rows %d-%d): ", w, w*32, w*32+31);
for (int c = 0; c < 8; c++) printf("%.4f ", h_s_out[w*8+c]);
printf("\n");
}
printf("S[0,0..7] scalar: ");
for (int c = 0; c < 8; c++) printf("%.4f ", h_s_scalar[c]);
// Print S[0,0..7] from each warp's row 0
printf("Row 0 (MMA): ");
for (int c = 0; c < 8; c++) printf("%.6f ", h_s_out[0*8+c]);
printf("\nRow 0 scalar: ");
for (int c = 0; c < 8; c++) printf("%.6f ", h_s_scalar[c]);
printf("\n");
printf("Ratio warp0/col0 vs scalar[0]: %.4f\n",
h_s_scalar[0] != 0 ? h_s_out[0] / h_s_scalar[0] : 0);
float max_diff = 0.0f, max_val = 0.0f;
for (int c = 0; c < 8; c++) {
max_diff = fmaxf(max_diff, fabsf(h_s_out[0*8+c] - h_s_scalar[c]));
max_val = fmaxf(max_val, fabsf(h_s_scalar[c]));
}
float rel_err = max_val > 0 ? max_diff / max_val : max_diff;
printf("Row 0 max rel err: %.6f\n", rel_err);
// Print a few more rows
for (int r : {32, 64, 96}) {
printf("Row %d: %.6f %.6f %.6f %.6f\n", r,
h_s_out[r*8+0], h_s_out[r*8+1], h_s_out[r*8+2], h_s_out[r*8+3]);
}
printf("Test %s\n", rel_err < 0.01f ? "PASSED" : "FAILED");
cudaFree(d_q); cudaFree(d_k); cudaFree(d_s_out); cudaFree(d_s_scalar);
free(h_q); free(h_k); free(h_s_out); free(h_s_scalar);
return 0;