From fee022a485d95a4e80ebf0ba4efdfde0ab90bc6d Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Thu, 28 May 2026 23:05:29 +0000
Subject: [PATCH] =?UTF-8?q?test:=20MMA=E2=86=924-warp=20read=20using=20pro?=
 =?UTF-8?q?ven=20fmha=5Fcommon+umma=5Fdesc=20infra?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/unit/test_tmem_4warp_read.cu | 247 +++++++++++++++++------------
 1 file changed, 145 insertions(+), 102 deletions(-)

diff --git a/tests/unit/test_tmem_4warp_read.cu b/tests/unit/test_tmem_4warp_read.cu
index 954735ab..3a3b8b97 100644
--- a/tests/unit/test_tmem_4warp_read.cu
+++ b/tests/unit/test_tmem_4warp_read.cu
@@ -1,138 +1,181 @@
 /**
- * Test: Can we do 16x256b.x1 LOADS multiple times without crashing?
- * Uses uint32_t registers (not float) per the working test_tmem_minimal.cu pattern.
+ * MINIMAL test: After UMMA (QK GEMM), can 4 warps read different rows?
+ *
+ * Uses EXACTLY the same SMEM loading and MMA pattern as fmha_6warp_multihead.cuh
+ * (which is proven to work for T=1 decode).
+ *
+ * Then has 4 warps each do 32x32b.x8 read and report what they see.
  */
-#include
-#include
-#include
-#include
+#include "dsv4/kernels/attention/fmha_common.cuh"
+#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
 
-using bf16_t = unsigned short;
+using namespace dsv4::kernels::attention;
 
-__device__ __forceinline__ bf16_t f32_to_bf16(float f) {
-    bf16_t h; asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h;
-}
+static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
+static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; }
 
-__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
-    asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
-        :: "r"(smem_ptr), "r"(num_cols));
-}
-__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
-    asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
-        :: "r"(tmem_ptr), "r"(num_cols));
-}
+constexpr int HD = 64;
+constexpr int SK = 128;
+constexpr int NKT_QK = HD / MMA_K_BF16;  // 4
+constexpr int TILE_SZ = 128 * MMA_K_BF16;  // 2048
+constexpr int TMEM_N = 128;
+constexpr int CORES_MN = 16;
 
-__global__ void __launch_bounds__(32)
-test_16x256b_loads(float* results) {
-    const int lane = threadIdx.x;
-    const int TMEM_N = 128;
+__global__ void __launch_bounds__(192)
+test_mma_rows(float* results) {
+    const int tid = threadIdx.x;
+    const int wid = tid / 32;
+    const int lane = tid % 32;
 
     extern __shared__ char sbuf[];
     uint32_t* sTmemBase = (uint32_t*)sbuf;
+    float* sRowMax = (float*)(sbuf + 8);
+    bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sRowMax + 128) + 127) & ~(uintptr_t)127);
+    bf16_t* sK0 = sQ0 + TILE_SZ;
 
-    uint32_t sp = __cvta_generic_to_shared(sTmemBase);
-    tmem_alloc(sp, TMEM_N);
-    __syncwarp();
+    // TMEM alloc
+    if (wid == 4) {
+        uint32_t sp = __cvta_generic_to_shared(sTmemBase);
+        tmem_alloc(sp, TMEM_N);
+    }
+    __syncthreads();
     uint32_t tb = *sTmemBase;
 
-    // Write data via 32x32b.x8 (known working)
-    {
-        float vals[8];
-        for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c);
-        uint32_t ivals[8];
-        for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4);
+    // Q and K are synthesized directly in SMEM by the load warp (wid == 5);
+    // nothing is staged through GMEM in this test.
+    //   Q: 128 x HD, Q[r, 0] = r + 1, every other element zero
+    //   K: SK x HD,  K[c, 0] = 1.0,   every other element zero
+    // Only d = 0 is non-zero and it belongs to K-tile 0, so the tiles staged
+    // for kt = 1..3 are all-zero and add nothing to the accumulator:
+    //   S[r, c] = sum_d Q[r, d] * K[c, d] = (r + 1) * 1.0 = r + 1
 
-        asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
-            :: "r"(tb + 0),
-               "r"(ivals[0]), "r"(ivals[1]), "r"(ivals[2]), "r"(ivals[3]),
-               "r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7]));
-        asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
+    // The loop below stages one K-tile per iteration and issues one UMMA for it;
+    // kt > 0 is forwarded to umma_ss_f16 (presumably its accumulate flag).
+
+    for (int kt = 0; kt < NKT_QK; kt++) {
+        // Load Q and K for this K-tile
+        if (wid == 5) {
+            for (int i = lane; i < TILE_SZ; i += 32) { sQ0[i] = 0; sK0[i] = 0; }
+
+            // Q: only first K-tile (kt==0) has data at d=0
+            if (kt == 0) {
+                for (int r = 0; r < 128; r++) {
+                    int d = 0;
+                    int ck = d / 8, lc = d % 8;
+                    int core_mn = r / 8, local_r = r % 8;
+                    sQ0[ck * CORES_MN * 64 + core_mn * 64 + local_r * 8 + lc] = f32_to_bf16((float)(r + 1));
+                }
+                // K: only d=0 has data
+                for (int r = 0; r < SK; r++) {
+                    int d = 0;
+                    int ck = d / 8, lc = d % 8;
+                    int tmn = r / 8, lr = r % 8;
+                    sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(1.0f);
+                }
+            }
+        }
+        __syncthreads();
+
+        // MMA: one thread of warp 4 issues the UMMA for this K-tile.
+        // NOTE(review): no tcgen05.commit / mbarrier wait is visible in this
+        // file; confirm umma_ss_f16 guarantees completion before the next
+        // iteration overwrites sQ0/sK0 and before TMEM is read below.
+        if (wid == 4) {
+            uint32_t idesc = make_idesc(128, 128);
+            uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
+            uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128);
+            if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
+            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
+        }
+        __syncthreads();
     }
 
-    // 16x256b.x1 loads using uint32_t registers (matching test_tmem_minimal.cu)
-    int load_count = 0;
+    // Extra sync for TMEM visibility
+    asm volatile("fence.sc.gpu;" ::: "memory");
+    __syncthreads();
 
-    // Load column 0
-    {
-        uint32_t u0, u1, u2, u3;
-        asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
-            : "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3) : "r"(tb + 0));
+    // Now read TMEM with 32x32b.x8 from each of the 4 softmax warps
+    // Read column group 0 (columns 0-7)
+    if (wid < 4) {
+        float tmp[8];
+        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
+            : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
+              "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
+            : "r"(tb + 0));
         asm volatile("tcgen05.wait::ld.sync.aligned;");
-        load_count++;
 
-        float v0, v1, v2, v3;
-        memcpy(&v0, &u0, 4); memcpy(&v1, &u1, 4);
-        memcpy(&v2, &u2, 4); memcpy(&v3, &u3, 4);
-
-        if (lane == 0) { results[0]=v0; results[1]=v1; results[2]=v2; results[3]=v3; }
-        if (lane == 1) { results[4]=v0; results[5]=v1; results[6]=v2; results[7]=v3; }
+        // Lane 0 from each warp: should see row 0's col 0 value
+        // Expected: S[0, 0] = 1.0 (row 0+1)
+        // But which row does each warp see?
+        if (lane == 0) results[wid] = tmp[0];  // col 0
+        if (lane == 0) results[4 + wid] = tmp[1];  // col 1
+        // Lane 1: row 1
+        if (lane == 1) results[8 + wid] = tmp[0];
+        // Lane 15: row 15
+        if (lane == 15) results[12 + wid] = tmp[0];
+        // Lane 31: row 31
+        if (lane == 31) results[16 + wid] = tmp[0];
     }
+    __syncthreads();
 
-    // Load column 1 (2nd load — does it crash?)
-    {
-        uint32_t u0, u1, u2, u3;
-        asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
-            : "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3) : "r"(tb + 4));  // next 4 columns
-        asm volatile("tcgen05.wait::ld.sync.aligned;");
-        load_count++;
-
-        float v0, v1, v2, v3;
-        memcpy(&v0, &u0, 4); memcpy(&v1, &u1, 4);
-        memcpy(&v2, &u2, 4); memcpy(&v3, &u3, 4);
-
-        if (lane == 0) { results[8]=v0; results[9]=v1; results[10]=v2; results[11]=v3; }
-    }
-
-    // Load column 8 (3rd load)
-    {
-        uint32_t u0, u1, u2, u3;
-        asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
-            : "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3) : "r"(tb + 32));  // columns 8-11
-        asm volatile("tcgen05.wait::ld.sync.aligned;");
-        load_count++;
-
-        float v0, v1, v2, v3;
-        memcpy(&v0, &u0, 4); memcpy(&v1, &u1, 4);
-        memcpy(&v2, &u2, 4); memcpy(&v3, &u3, 4);
-
-        if (lane == 0) { results[12]=v0; results[13]=v1; results[14]=v2; results[15]=v3; }
-    }
-
-    if (lane == 0) results[16] = (float)load_count;
-    tmem_dealloc(tb, TMEM_N);
+    if (wid == 4) tmem_dealloc(tb, TMEM_N);
 }
 
 int main() {
-    printf("16x256b.x1 multiple LOAD test\n");
-    printf("==============================\n\n");
+    printf("MMA → 4-warp 32x32b.x8 read\n");
+    printf("=============================\n");
+    printf("Q[r,0] = r+1, K[c,0] = 1.0 → S[r,c] = r+1\n");
+    printf("Lane 0 in each warp should see row 0's value (1.0) if all warps see same rows,\n");
+    printf("or different values if warps see different rows.\n\n");
 
     float* d_r;
     cudaMalloc(&d_r, 32 * sizeof(float));
     cudaMemset(d_r, 0, 32 * sizeof(float));
-    test_16x256b_loads<<<1, 32, 256>>>(d_r);
+
+    // SMEM bytes: sTmemBase 8 + sRowMax 512, rounded up to 128 -> 640; + sQ0 4096 + sK0 4096 + 256 alignment slack = 9088
+    size_t smem_off = 8 + 128*4;
+    smem_off = ((smem_off + 127) & ~(size_t)127);
+    smem_off += 2 * TILE_SZ * sizeof(bf16_t) + 256;
+    int smem = (int)smem_off;
+
+    test_mma_rows<<<1, 192, smem>>>(d_r);
     cudaError_t err = cudaDeviceSynchronize();
+    // A launch the runtime rejected is reported by cudaGetLastError(), not by the sync.
+    if (err == cudaSuccess) err = cudaGetLastError();
 
     if (err != cudaSuccess) {
         printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
-        printf("16x256b.x1 loads CRASHED\n");
-        cudaFree(d_r);
-        return 1;
+        printf("Kernel crashed!\n");
+    } else {
+        float h[32];
+        cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
+
+        printf("Lane 0, col 0 (S[0,0]=1.0 if row 0): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
+               h[0], h[1], h[2], h[3]);
+        printf("Lane 0, col 1 (S[0,1]=1.0 if row 0): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
+               h[4], h[5], h[6], h[7]);
+        printf("Lane 1, col 0 (S[1,0]=2.0 if row 1): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
+               h[8], h[9], h[10], h[11]);
+        printf("Lane 15, col 0 (S[15,0]=16.0): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
+               h[12], h[13], h[14], h[15]);
+        printf("Lane 31, col 0 (S[31,0]=32.0): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
+               h[16], h[17], h[18], h[19]);
+
+        // Check if warps see different data
+        bool all_same = (h[0] == h[1]) && (h[1] == h[2]) && (h[2] == h[3]);
+        bool w0_correct = (fabsf(h[0] - 1.0f) < 0.01f);
+
+        if (w0_correct && all_same) {
+            printf("\nAll warps see the SAME 32 rows (rows 0-31). Warp 0 data is correct.\n");
+            printf("For T>32, we need an alternative to 32x32b.x8 for rows 32-127.\n");
+        } else if (!all_same) {
+            printf("\nWarps see DIFFERENT rows! Multi-warp softmax is possible!\n");
+        } else {
+            printf("\nUnexpected: warp 0 data incorrect (%.1f, expected 1.0)\n", h[0]);
+        }
     }
 
-    float h[32];
-    cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
-
-    printf("Load count: %d (3 loads completed = no crash)\n\n", (int)h[16]);
-
-    printf("Col 0, lane 0 (expect rows 0-3: 0,10,20,30): %.1f %.1f %.1f %.1f\n", h[0], h[1], h[2], h[3]);
-    printf("Col 0, lane 1 (expect rows 4-7: 40,50,60,70): %.1f %.1f %.1f %.1f\n", h[4], h[5], h[6], h[7]);
-    printf("Col 1, lane 0 (expect rows 0-3 at col 1: 1,11,21,31): %.1f %.1f %.1f %.1f\n", h[8], h[9], h[10], h[11]);
-    printf("Col 8, lane 0 (expect row 0 col 8 = 8, row 1 col 8 = 18, etc): %.1f %.1f %.1f %.1f\n", h[12], h[13], h[14], h[15]);
-
-    int pass = (fabsf(h[0]-0.0f)<0.01f) && (fabsf(h[1]-10.0f)<0.01f) &&
-               (fabsf(h[2]-20.0f)<0.01f) && (fabsf(h[3]-30.0f)<0.01f) && (int)h[16]==3;
-    printf("\nResult: %s\n", pass ? "16x256b.x1 LOADS work multiple times!" : "Data mismatch or crash");
 
     cudaFree(d_r);
-    return pass ? 0 : 1;
+    return (err == cudaSuccess) ? 0 : 1;
 }