test: gau-nernst pattern — fence::after_thread_sync, 4 warps, 128 threads, 32x32b.x8 loop
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
/**
|
||||
* UMMA QK GEMM Test (HD=16, SK=128)
|
||||
* Uses Layout D TMEM read (32x32b.x8, 4 warps) to correctly read MMA output.
|
||||
* Debugging the 4× scaling factor observed with 16x256b read.
|
||||
* Using gau-nernst's exact MMA + epilogue pattern:
|
||||
* - tcgen05.fence::after_thread_sync before TMEM read
|
||||
* - 32x32b.x8 TMEM read with row = warp_id * 32
|
||||
* - tcgen05.wait::ld.sync.aligned after each read
|
||||
* - Output: off_m + tid per thread (each thread = one row)
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
@@ -18,61 +21,56 @@ using namespace dsv4::kernels::attention;
|
||||
static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
|
||||
static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; }
|
||||
|
||||
__global__ void __launch_bounds__(NTHREADS)
|
||||
__global__ void __launch_bounds__(128) // 4 warps minimum for Layout D
|
||||
test_umma_qk_hd16(const bf16_t* q, const bf16_t* k,
|
||||
float* s_out, float* s_scalar, float scale)
|
||||
{
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / WARP, lane = tid % WARP;
|
||||
const int wid = tid / 32, lane = tid % 32;
|
||||
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf;
|
||||
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
|
||||
bf16_t* sK = sQ + 128 * 16 + 4096;
|
||||
bf16_t* sK = sQ + 128 * 16 + 4096; // 8KB padding after Q
|
||||
float* sQ_row = (float*)(sK + 128 * 16);
|
||||
|
||||
for (int d = tid; d < 16; d += NTHREADS) sQ_row[d] = bf16_to_f32(q[d]);
|
||||
for (int d = tid; d < 16; d += 128) sQ_row[d] = bf16_to_f32(q[d]);
|
||||
|
||||
// TMEM alloc (128 cols)
|
||||
if (wid == 0) {
|
||||
if (wid == 1) {
|
||||
tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
|
||||
}
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// Load Q and K
|
||||
// Load Q and K into SMEM in canonical layout
|
||||
write_q_to_smem<16>(sQ, q);
|
||||
write_k_to_smem<128, 16>(sK, k);
|
||||
// Zero the padding between Q and K
|
||||
bf16_t* sQ_pad = sQ + 128 * 16; // After Q data
|
||||
bf16_t* sK_end = sK + 128 * 16; // After K data
|
||||
for (int i = tid; i < 4096; i += NTHREADS) sQ_pad[i] = 0; // 8KB padding
|
||||
// Zero padding
|
||||
bf16_t* sQ_pad = sQ + 128 * 16;
|
||||
for (int i = tid; i < 4096; i += 128) sQ_pad[i] = 0;
|
||||
__syncthreads();
|
||||
|
||||
// Descriptors
|
||||
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
|
||||
uint32_t sK_smem = __cvta_generic_to_shared(sK);
|
||||
// Try LBO = 32 (128/4 * 16 / 16 = 32 in 16B units)
|
||||
// Hypothesis: M=128 has 4 sub-tiles, each with 32 rows
|
||||
// So LBO should be 32 * 16 = 512 bytes (32 in 16B units)
|
||||
uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 32); // LBO = 32 * 16 = 512B
|
||||
uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 32);
|
||||
uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 128);
|
||||
uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 128);
|
||||
uint32_t idesc = make_idesc(128, 128);
|
||||
|
||||
// MMA — 1 thread calls (4× scaling is expected for M=128 cta_group::1)
|
||||
// MMA — 1 thread issues (following gau-nernst pattern)
|
||||
if (tid == 0) {
|
||||
umma_ss_f16(tb, desc_q, desc_k, idesc, false);
|
||||
}
|
||||
__syncwarp();
|
||||
if (wid == 0 && lane == 0) tmem_fence_store();
|
||||
// tcgen05.fence::after_thread_sync (CRITICAL — correct MMA→TMEM load fence)
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// Read from TMEM using Layout D: 32x32b.x8 format
|
||||
// Read from TMEM using Layout D: 32x32b.x8 (gau-nernst pattern)
|
||||
// Each warp reads 32 rows × 8 columns
|
||||
// Warp 0: rows 0-31, warp 1: rows 32-63, warp 2: rows 64-95, warp 3: rows 96-127
|
||||
if (wid < 4) {
|
||||
for (int n = 0; n < 128 / 8; n++) {
|
||||
const int row = wid * 32;
|
||||
const int col = 0;
|
||||
const int col = n * 8;
|
||||
const int addr = tb + (row << 16) + col;
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];"
|
||||
@@ -80,10 +78,15 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k,
|
||||
"=f"(tmp[4]), "=f"(tmp[5]), "=f"(tmp[6]), "=f"(tmp[7])
|
||||
: "r"(addr));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
// Lane 0 gets S[wid*32 + 0, 0..7] in tmp[0..7]
|
||||
if (lane == 0) {
|
||||
|
||||
// Lane 0 writes first 8 output values for this warp's row
|
||||
// Each lane gets values for its row (tid = wid*32 + lane)
|
||||
int out_row = wid * 32 + lane;
|
||||
if (lane == 0 && n < 1) { // Only first 8 cols for debug
|
||||
for (int c = 0; c < 8; c++) {
|
||||
s_out[wid * 8 + c] = tmp[c] * 0.25f; // Divide by 4 (Layout D scaling)
|
||||
if (out_row < 128) {
|
||||
s_out[out_row * 8 + c] = tmp[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -100,17 +103,18 @@ test_umma_qk_hd16(const bf16_t* q, const bf16_t* k,
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// TMEM dealloc
|
||||
if (wid == 0) tmem_dealloc(tb, 128);
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("=== UMMA QK GEMM Test (Layout D read) ===\n");
|
||||
printf("=== UMMA QK GEMM (gau-nernst pattern, 4 warps) ===\n");
|
||||
const int HD = 16, SK = 128;
|
||||
const float SCALE = 1.0f / sqrtf((float)HD);
|
||||
|
||||
bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t));
|
||||
bf16_t* h_k = (bf16_t*)malloc(SK * HD * sizeof(bf16_t));
|
||||
float* h_s_out = (float*)calloc(256, sizeof(float));
|
||||
float* h_s_out = (float*)calloc(128*8, sizeof(float));
|
||||
float* h_s_scalar = (float*)calloc(SK, sizeof(float));
|
||||
|
||||
srand(42);
|
||||
@@ -119,33 +123,42 @@ int main() {
|
||||
|
||||
bf16_t *d_q, *d_k; float *d_s_out, *d_s_scalar;
|
||||
cudaMalloc(&d_q, HD*sizeof(bf16_t)); cudaMalloc(&d_k, SK*HD*sizeof(bf16_t));
|
||||
cudaMalloc(&d_s_out, 256*sizeof(float)); cudaMalloc(&d_s_scalar, SK*sizeof(float));
|
||||
cudaMalloc(&d_s_out, 128*8*sizeof(float)); cudaMalloc(&d_s_scalar, SK*sizeof(float));
|
||||
cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemset(d_s_out, 0, 256*sizeof(float)); cudaMemset(d_s_scalar, 0, SK*sizeof(float));
|
||||
cudaMemset(d_s_out, 0, 128*8*sizeof(float)); cudaMemset(d_s_scalar, 0, SK*sizeof(float));
|
||||
|
||||
int smem = (4 + 16 + 128*16*2 + 4096*2 + 128*16*2 + 16*4 + 256 + 127) & ~127;
|
||||
test_umma_qk_hd16<<<1, NTHREADS, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE);
|
||||
test_umma_qk_hd16<<<1, 128, smem>>>(d_q, d_k, d_s_out, d_s_scalar, SCALE);
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
|
||||
|
||||
cudaMemcpy(h_s_out, d_s_out, 256*sizeof(float), cudaMemcpyDeviceToHost);
|
||||
cudaMemcpy(h_s_out, d_s_out, 128*8*sizeof(float), cudaMemcpyDeviceToHost);
|
||||
cudaMemcpy(h_s_scalar, d_s_scalar, SK*sizeof(float), cudaMemcpyDeviceToHost);
|
||||
|
||||
// Print S[0,0] from each warp's perspective
|
||||
printf("S from TMEM (Layout D, 32x32b.x8):\n");
|
||||
for (int w = 0; w < 4; w++) {
|
||||
printf(" Warp %d (rows %d-%d): ", w, w*32, w*32+31);
|
||||
for (int c = 0; c < 8; c++) printf("%.4f ", h_s_out[w*8+c]);
|
||||
printf("\n");
|
||||
}
|
||||
printf("S[0,0..7] scalar: ");
|
||||
for (int c = 0; c < 8; c++) printf("%.4f ", h_s_scalar[c]);
|
||||
// Print S[0,0..7] from each warp's row 0
|
||||
printf("Row 0 (MMA): ");
|
||||
for (int c = 0; c < 8; c++) printf("%.6f ", h_s_out[0*8+c]);
|
||||
printf("\nRow 0 scalar: ");
|
||||
for (int c = 0; c < 8; c++) printf("%.6f ", h_s_scalar[c]);
|
||||
printf("\n");
|
||||
printf("Ratio warp0/col0 vs scalar[0]: %.4f\n",
|
||||
h_s_scalar[0] != 0 ? h_s_out[0] / h_s_scalar[0] : 0);
|
||||
|
||||
float max_diff = 0.0f, max_val = 0.0f;
|
||||
for (int c = 0; c < 8; c++) {
|
||||
max_diff = fmaxf(max_diff, fabsf(h_s_out[0*8+c] - h_s_scalar[c]));
|
||||
max_val = fmaxf(max_val, fabsf(h_s_scalar[c]));
|
||||
}
|
||||
float rel_err = max_val > 0 ? max_diff / max_val : max_diff;
|
||||
printf("Row 0 max rel err: %.6f\n", rel_err);
|
||||
|
||||
// Print a few more rows
|
||||
for (int r : {32, 64, 96}) {
|
||||
printf("Row %d: %.6f %.6f %.6f %.6f\n", r,
|
||||
h_s_out[r*8+0], h_s_out[r*8+1], h_s_out[r*8+2], h_s_out[r*8+3]);
|
||||
}
|
||||
|
||||
printf("Test %s\n", rel_err < 0.01f ? "PASSED" : "FAILED");
|
||||
cudaFree(d_q); cudaFree(d_k); cudaFree(d_s_out); cudaFree(d_s_scalar);
|
||||
free(h_q); free(h_k); free(h_s_out); free(h_s_scalar);
|
||||
return 0;
|
||||
|
||||
Reference in New Issue
Block a user