From be45e87891b45e3428082edda1c6ea55d3859273 Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Thu, 28 May 2026 23:00:27 +0000
Subject: [PATCH] =?UTF-8?q?test:=20MMA=E2=86=924-warp=20TMEM=20read=20?=
 =?UTF-8?q?=E2=80=94=20do=20warps=20see=20different=20rows=3F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Review notes (this section is ignored by `git am`):

  * Line breaks were restored; the patch had been collapsed onto a few lines.
    Indentation and blank context lines are reconstructed, and the hunk
    headers were recomputed from this text. If a hunk is rejected, apply with
    `git apply --ignore-whitespace` or regenerate with `git format-patch`.
    The trailing context of the first hunk was trimmed because the #include
    target was not recoverable. The post-image blob id on the `index` line
    is stale.
  * Launch now requests enough dynamic SMEM for sTmemBase + alignment slack +
    both 128x16 BF16 tiles (8576 B); 4096 B overflowed the allocation.
  * write_canonical() is called by a single warp, so it now strides by lane
    (32) instead of by block size (192); previously ~5/6 of each tile was
    never written.
  * tcgen05.mma is asynchronous: the issuing thread now commits to an
    mbarrier and all threads wait on it before tcgen05.ld. The fence.sc.gpu
    + __syncthreads() pair did not wait for the tensor core.
  * Added fence.proxy.async after the generic-proxy SMEM writes.
  * Host side checks cudaMalloc/cudaMemset/launch/cudaMemcpy and reports
    INCONCLUSIVE when warp 0 does not read the expected values.

 tests/unit/test_tmem_4warp_read.cu | 259 +++++++++++++++++++++--------
 1 file changed, 191 insertions(+), 68 deletions(-)

diff --git a/tests/unit/test_tmem_4warp_read.cu b/tests/unit/test_tmem_4warp_read.cu
index ec70d33d..a3f26922 100644
--- a/tests/unit/test_tmem_4warp_read.cu
+++ b/tests/unit/test_tmem_4warp_read.cu
@@ -1,5 +1,12 @@
 /**
- * Test: TMEM cross-warp visibility after 32x32b.x8 store.
- * Try different synchronization strategies to make TMEM stores
- * visible to all warps.
+ * Test: Can 4 warps read TMEM data written by UMMA (not 32x32b.x8 store)?
+ *
+ * The MMA writes to TMEM via the tcgen05.mma instruction, which uses
+ * the full 128-thread hardware pipeline. This should produce TMEM data
+ * that's visible to all warps in the CTA.
+ *
+ * If this works, then multi-row softmax is possible:
+ *   - QK MMA writes S to TMEM (hardware pipeline)
+ *   - 4 warps read S with 32x32b.x8 (each sees 32 rows)
+ *   - The question is: does each warp see DIFFERENT 32 rows, or the same 32?
  */
@@ -14,6 +21,9 @@ using bf16_t = unsigned short;
 __device__ __forceinline__ bf16_t f32_to_bf16(float f) {
     bf16_t h; asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h;
 }
+__device__ __forceinline__ float bf16_to_f32(bf16_t h) {
+    float f; asm("cvt.f32.bf16 %0, %1;" : "=f"(f) : "h"(h)); return f;
+}
 
 __device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
     asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
@@ -24,9 +34,77 @@ __device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
                  :: "r"(tmem_ptr), "r"(num_cols));
 }
 
-// Strategy: warp 0 stores, then ALL warps fence + sync, then read
-__global__ void __launch_bounds__(128)
-test_tmem_sync(float* results, int strategy) {
+// Descriptor fields are stored in units of 16 bytes.
+__device__ __forceinline__ uint64_t desc_encode(uint64_t byte_val) { return byte_val >> 4; }
+// Pack the 64-bit shared-memory matrix descriptor handed to tcgen05.mma:
+//   bits [0,14)  = smem_addr >> 4
+//   bits [16,30) = (block_mn * 16 bytes) >> 4
+//   bits [32,46) = (128 bytes) >> 4
+//   bit  46      = 1
+// NOTE(review): presumably start address / leading offset / stride offset /
+// descriptor version with no swizzle -- confirm against the PTX ISA.
+__device__ __forceinline__ uint64_t make_umma_desc(uint32_t smem_addr, int block_mn) {
+    uint64_t desc = 0;
+    desc |= desc_encode(smem_addr) & 0x3FFF;
+    desc |= (desc_encode((uint64_t)block_mn * 16) & 0x3FFF) << 16;
+    desc |= (desc_encode(128ULL) & 0x3FFF) << 32;
+    desc |= 1ULL << 46;
+    return desc;
+}
+// Instruction descriptor: bits 4, 7, 10 set; block_n/8 at bit 17; block_m/16 at bit 24.
+__device__ __forceinline__ uint32_t make_idesc(int block_m, int block_n) {
+    return (1U << 4) | (1U << 7) | (1U << 10) |
+           ((uint32_t)(block_n >> 3) << 17) | ((uint32_t)(block_m >> 4) << 24);
+}
+// Issue one tcgen05.mma (cta_group::1, kind::f16) targeting TMEM address tmem_c.
+// acc selects the trailing predicate operand; the four mask registers are zero.
+__device__ void umma_ss(uint32_t tmem_c, uint64_t da, uint64_t db, uint32_t idesc, bool acc) {
+    uint32_t sc = acc ? 0x3F800000u : 0u;
+    asm volatile("{\n\t.reg .pred p;\n\tsetp.ne.b32 p, %4, 0;\n\t"
+        "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5,%6,%7,%8}, p;\n\t}"
+        :: "r"(tmem_c), "l"(da), "l"(db), "r"(idesc), "r"(sc), "r"(0), "r"(0), "r"(0), "r"(0));
+}
+
+// mbarrier helpers. tcgen05.mma completes asynchronously, so the issuing thread
+// commits to an mbarrier and every reader waits on it before loading from TMEM.
+__device__ __forceinline__ void mbar_init(uint32_t bar, uint32_t count) {
+    asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(bar), "r"(count) : "memory");
+    asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
+}
+// Arrive on `bar` once all tcgen05 ops previously issued by this thread complete.
+__device__ __forceinline__ void umma_commit(uint32_t bar) {
+    asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
+                 :: "r"(bar) : "memory");
+}
+// Spin until the barrier phase with the given parity has completed.
+__device__ __forceinline__ void mbar_wait(uint32_t bar, uint32_t parity) {
+    uint32_t done = 0;
+    while (!done) {
+        asm volatile("{\n\t.reg .pred p;\n\t"
+            "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
+            "selp.b32 %0, 1, 0, p;\n\t}"
+            : "=r"(done) : "r"(bar), "r"(parity) : "memory");
+    }
+}
+
+// Write a row-major (rows, cols16) BF16 matrix into the SMEM canonical layout
+// (8x8 core matrices). Call from all 32 lanes of ONE warp, each lane holding the
+// full `src`: the lanes split the elements with a stride of 32.
+__device__ void write_canonical(bf16_t* dst, const bf16_t* src, int rows, int cols16) {
+    constexpr int CORES_MN = 16;
+    for (int i = threadIdx.x % 32; i < 128 * cols16; i += 32) {
+        int r = i / cols16, c = i % cols16;
+        if (r < rows && c < cols16) {
+            int core_mn = r / 8, local_r = r % 8;
+            int core_k = c / 8, local_c = c % 8;
+            int dst_idx = core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c;
+            dst[dst_idx] = src[r * cols16 + c];
+        }
+    }
+}
+
+__global__ void __launch_bounds__(192)
+test_mma_4warp_read(float* results) {
     const int tid = threadIdx.x;
     const int wid = tid / 32;
     const int lane = tid % 32;
@@ -34,44 +112,61 @@ test_tmem_sync(float* results, int strategy) {
 
     extern __shared__ char sbuf[];
     uint32_t* sTmemBase = (uint32_t*)sbuf;
+    bf16_t* sA = (bf16_t*)(((uintptr_t)(sbuf + 256) + 127) & ~(uintptr_t)127);
+    bf16_t* sB = sA + 128 * 16;
+    // mbarrier the MMA commit arrives on (one arrival per phase).
+    __shared__ __align__(8) uint64_t sMmaBar;
+    const uint32_t mma_bar = (uint32_t)__cvta_generic_to_shared(&sMmaBar);
 
-    if (wid == 0) {
+    // TMEM alloc (warp 4)
+    if (wid == 4) {
         uint32_t sp = __cvta_generic_to_shared(sTmemBase);
         tmem_alloc(sp, TMEM_N);
     }
+    if (tid == 0) mbar_init(mma_bar, 1);
     __syncthreads();
     uint32_t tb = *sTmemBase;
 
-    // Warp 0 stores data
-    if (wid == 0) {
-        float vals[8];
-        for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c);
-        uint32_t ivals[8];
-        for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4);
+    // Build A such that A[row, 0] = row+1 (row 0 = 1.0, row 1 = 2.0, etc.)
+    // This way we can identify which row each lane is reading
+    if (wid == 5) {  // load warp
+        bf16_t tmp[128 * 16];
+        for (int i = 0; i < 128 * 16; i++) tmp[i] = f32_to_bf16(0.0f);
+        for (int r = 0; r < 128; r++) tmp[r * 16 + 0] = f32_to_bf16((float)(r + 1));
+        write_canonical(sA, tmp, 128, 16);
 
-        asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
-            :: "r"(tb + 0),
-               "r"(ivals[0]), "r"(ivals[1]), "r"(ivals[2]), "r"(ivals[3]),
-               "r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7]));
-        asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
+        // B: B[col, 0] = 1.0 for col < 128
+        for (int i = 0; i < 128 * 16; i++) tmp[i] = f32_to_bf16(0.0f);
+        for (int c = 0; c < 128; c++) tmp[c * 16 + 0] = f32_to_bf16(1.0f);
+        write_canonical(sB, tmp, 128, 16);
+        // The tiles were written with ordinary (generic-proxy) stores but the MMA
+        // reads SMEM through the async proxy; order the writes for that proxy.
+        asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
     }
+    __syncthreads();
 
-    // Strategy 0: just __syncthreads
-    // Strategy 1: __syncthreads + tcgen05.wait::st on all warps
-    // Strategy 2: __syncthreads + fence + __syncthreads
-    if (strategy == 0) {
-        __syncthreads();
-    } else if (strategy == 1) {
-        __syncthreads();
-        asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
-        __syncthreads();
-    } else if (strategy == 2) {
-        __syncthreads();
-        asm volatile("fence.sc.gpu;" ::: "memory");
-        __syncthreads();
+    // MMA: A(128,16) × B(128,16) → S(128,128) in TMEM
+    // S[row, col] should be (row+1) for col < 128
+    if (wid == 4) {
+        uint32_t idesc = make_idesc(128, 128);
+        uint64_t da = make_umma_desc(__cvta_generic_to_shared(sA), 128);
+        uint64_t db = make_umma_desc(__cvta_generic_to_shared(sB), 128);
+        if (tid == 128) {
+            umma_ss(tb, da, db, idesc, false);
+            // Signal completion of the asynchronous MMA on the mbarrier.
+            umma_commit(mma_bar);
+        }
     }
+    // A CTA barrier does not wait for the tensor core: block on the mbarrier
+    // (phase 0), then order the following tcgen05.ld after the wait.
+    mbar_wait(mma_bar, 0);
+    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
+    __syncwarp();
 
-    // All 4 warps read
+    // 4 warps read TMEM with 32x32b.x8
+    // Read column group 0 (columns 0-7)
+    // NOTE(review): every warp passes the same address (tb + 0). Whether a warp
+    // must instead add its own lane base to the address is what this test probes.
     if (wid < 4) {
         float tmp[8];
         asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
@@ -80,51 +175,79 @@ test_tmem_sync(float* results, int strategy) {
             : "r"(tb + 0));
         asm volatile("tcgen05.wait::ld.sync.aligned;");
 
-        // Lane 0: store col 0 value
+        // Lane 0: store col 0 value (should be row_id+1)
         if (lane == 0) results[wid] = tmp[0];
-        // Lane 5: store col 0 value (should be 50.0)
-        if (lane == 5) results[4 + wid] = tmp[0];
+        // Lane 1
+        if (lane == 1) results[4 + wid] = tmp[0];
+        // Lane 15
+        if (lane == 15) results[8 + wid] = tmp[0];
+        // Lane 31
+        if (lane == 31) results[12 + wid] = tmp[0];
+
+        // Lane 0, also store col 1 value
+        if (lane == 0) results[16 + wid] = tmp[1];
     }
     __syncthreads();
-    if (wid == 0) tmem_dealloc(tb, TMEM_N);
+    if (wid == 4) tmem_dealloc(tb, TMEM_N);
 }
 
 int main() {
-    printf("TMEM cross-warp visibility test\n");
-    printf("================================\n\n");
+    printf("MMA → 4-warp TMEM read test\n");
+    printf("============================\n");
+    printf("After UMMA, S[row, col] = row+1 for col < 128.\n");
+    printf("If 4 warps see different rows, lane l in warp w should see row w*32+l.\n\n");
 
-    for (int strat = 0; strat < 3; strat++) {
-        const char* names[] = {"__syncthreads only",
-                               "__syncthreads + tcgen05.wait::st + __syncthreads",
-                               "__syncthreads + fence.sc.gpu + __syncthreads"};
-        printf("Strategy %d: %s\n", strat, names[strat]);
+    // Dynamic SMEM: 256 B for sTmemBase, up to 128 B of alignment slack for sA,
+    // then the A and B tiles (128 x 16 BF16 each).
+    constexpr size_t SMEM_BYTES = 256 + 128 + 2 * 128 * 16 * sizeof(bf16_t);
+    float* d_r = nullptr;
+    cudaError_t err = cudaMalloc(&d_r, 32 * sizeof(float));
+    if (err == cudaSuccess) err = cudaMemset(d_r, 0, 32 * sizeof(float));
+    if (err == cudaSuccess) {
+        test_mma_4warp_read<<<1, 192, SMEM_BYTES>>>(d_r);
+        err = cudaGetLastError();  // launch-configuration errors
+    }
 
-        float* d_r;
-        cudaMalloc(&d_r, 32 * sizeof(float));
-        cudaMemset(d_r, 0, 32 * sizeof(float));
-        test_tmem_sync<<<1, 128, 256>>>(d_r, strat);
-
-        cudaError_t err = cudaDeviceSynchronize();
-        if (err != cudaSuccess) {
-            printf(" CUDA ERROR: %s\n", cudaGetErrorString(err));
-        } else {
-            float h[32];
-            cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
-            printf(" Lane 0, col 0 (expect 0.0 for all): ");
-            for(int w=0;w<4;w++) printf("w%d=%.0f ", w, h[w]);
-            printf("\n");
-            printf(" Lane 5, col 0 (expect 50.0 for all): ");
-            for(int w=0;w<4;w++) printf("w%d=%.0f ", w, h[4+w]);
-            printf("\n");
-
-            int ok = 1;
-            for(int w=0;w<4;w++) {
-                if (fabsf(h[w] - 0.0f) > 0.01f) ok = 0;
-                if (fabsf(h[4+w] - 50.0f) > 0.01f) ok = 0;
-            }
-            printf(" Result: %s\n\n", ok ? "ALL WARPS SEE DATA" : "SOME WARPS SEE ZEROS");
-        }
+    if (err == cudaSuccess) err = cudaDeviceSynchronize();
+    if (err != cudaSuccess) {
+        printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
         cudaFree(d_r);
+        return 1;
     }
+    float h[32];
+    err = cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
+    if (err != cudaSuccess) {
+        printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
+        cudaFree(d_r);
+        return 1;
+    }
+
+    printf("Lane 0, col 0 (expect row_id+1): ");
+    printf("w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[0], h[1], h[2], h[3]);
+    printf(" If w0=1.0 w1=33.0 w2=65.0 w3=97.0 → each warp sees different 32 rows ✓\n");
+    printf(" If all same → all warps see same 32 rows\n\n");
+
+    printf("Lane 1, col 0: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[4], h[5], h[6], h[7]);
+    printf("Lane 15, col 0: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[8], h[9], h[10], h[11]);
+    printf("Lane 31, col 0: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[12], h[13], h[14], h[15]);
+    printf("Lane 0, col 1: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[16], h[17], h[18], h[19]);
+
+    // Sanity check on warp 0 (expected: lane l reads row l, i.e. value l+1).
+    // Without it an all-zero readback would be reported as "SAME rows".
+    int sane = (h[0] == 1.0f) && (h[4] == 2.0f) && (h[8] == 16.0f) &&
+               (h[12] == 32.0f) && (h[16] == 1.0f);
+    if (!sane) {
+        printf("\nConclusion: INCONCLUSIVE — warp 0 did not read the expected MMA output\n");
+        cudaFree(d_r);
+        return 1;
+    }
+
+    // Check if warps see different rows
+    int diff = (h[0] != h[1]) || (h[1] != h[2]) || (h[2] != h[3]);
+    printf("\nConclusion: %s\n", diff ?
+        "Warps see DIFFERENT rows — multi-warp softmax works!" :
+        "Warps see SAME rows — need alternative approach");
+
+    cudaFree(d_r);
     return 0;
 }