From 91b03bd6bdb3eeb57f4bb19fb5614ab89b5ec671 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 22:57:59 +0000 Subject: [PATCH] test: verify 4-warp TMEM read with 32x32b.x8 after MMA --- tests/unit/test_tmem_4warp_read.cu | 215 +++++++++++++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 tests/unit/test_tmem_4warp_read.cu diff --git a/tests/unit/test_tmem_4warp_read.cu b/tests/unit/test_tmem_4warp_read.cu new file mode 100644 index 00000000..39e29321 --- /dev/null +++ b/tests/unit/test_tmem_4warp_read.cu @@ -0,0 +1,215 @@ +/** + * Test: verify that 4 warps reading TMEM with 32x32b.x8 each see different rows. + * + * Write a known pattern to TMEM via UMMA (like QK GEMM), + * then have 4 warps read with 32x32b.x8 and check which rows they see. + * + * If warp w naturally sees rows [w*32, (w+1)*32), then multi-row softmax + * works with 4 warps using 32x32b.x8 and NO row offset. + */ + +#include +#include +#include +#include + +using bf16_t = unsigned short; + +__device__ __forceinline__ bf16_t f32_to_bf16(float f) { + bf16_t h; asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h; +} +__device__ __forceinline__ float bf16_to_f32(bf16_t h) { + float f; asm("cvt.f32.bf16 %0, %1;" : "=f"(f) : "h"(h)); return f; +} + +__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) { + asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" + :: "r"(smem_ptr), "r"(num_cols)); +} +__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) { + asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;" + :: "r"(tmem_ptr), "r"(num_cols)); +} + +// Write SMEM in canonical layout for (128, 16) BF16 +__device__ void write_smem_canonical_128x16(bf16_t* dst, const bf16_t* src, int rows, int cols) { + const int N = 192; // total threads + const int tid = threadIdx.x; + constexpr int CORES_MN = 16; // 128/8 + for (int i = tid; i < 128 * 16; i += N) { + int r = i / 16, c = i % 16; + int core_mn = r / 8, local_r = r % 8; + int core_k = c / 8, local_c = c % 8; + int dst_idx = core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c; + dst[dst_idx] = (r < rows && c < cols) ? src[r * cols + c] : f32_to_bf16(0.0f); + } +} + +// UMMA descriptor + MMA +__device__ __forceinline__ uint64_t desc_encode(uint64_t byte_val) { return byte_val >> 4; } +__device__ __forceinline__ uint64_t make_umma_desc(uint32_t smem_addr, int block_mn) { + uint64_t desc = 0; + desc |= desc_encode(smem_addr) & 0x3FFF; + desc |= (desc_encode((uint64_t)block_mn * 16) & 0x3FFF) << 16; + desc |= (desc_encode(128ULL) & 0x3FFF) << 32; + desc |= 1ULL << 46; + return desc; +} +__device__ __forceinline__ uint32_t make_idesc(int block_m, int block_n) { + return (1U << 4) | (1U << 7) | (1U << 10) | + ((uint32_t)(block_n >> 3) << 17) | ((uint32_t)(block_m >> 4) << 24); +} +__device__ void umma_ss(uint32_t tmem_c, uint64_t da, uint64_t db, uint32_t idesc, bool acc) { + uint32_t sc = acc ? 0x3F800000u : 0u; + asm volatile("{\n\t.reg .pred p;\n\tsetp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5,%6,%7,%8}, p;\n\t}" + :: "r"(tmem_c), "l"(da), "l"(db), "r"(idesc), "r"(sc), "r"(0), "r"(0), "r"(0), "r"(0)); +} + +__global__ void __launch_bounds__(192) +test_tmem_4warp_read(float* results) { + const int tid = threadIdx.x; + const int wid = tid / 32; + const int lane = tid % 32; + const int TMEM_N = 128; + + extern __shared__ char sbuf[]; + uint32_t* sTmemBase = (uint32_t*)sbuf; + bf16_t* sA = (bf16_t*)(((uintptr_t)(sbuf + 256) + 127) & ~(uintptr_t)127); + bf16_t* sB = sA + 128 * 16; // (128, 16) BF16 canonical + + // Create a known pattern: A[row, 0] = row (each row has a unique value in column 0) + // A is (128, 16), B is (128, 16) identity-like + // After MMA: S[row, col] = sum_d A[row, d] * B[col, d] + // If A[row, 0] = row and B has identity structure, S[row, col] should have row info + + // Fill A: row r has A[r, 0] = r (just column 0) + if (wid == 5) { + bf16_t tmp[128 * 16]; + for (int i = 0; i < 128 * 16; i++) tmp[i] = f32_to_bf16(0.0f); + for (int r = 0; r < 128; r++) tmp[r * 16 + 0] = f32_to_bf16((float)r); + write_smem_canonical_128x16(sA, tmp, 128, 16); + + // Fill B: identity-like, B[col, 0] = 1 for col < 128 + for (int i = 0; i < 128 * 16; i++) tmp[i] = f32_to_bf16(0.0f); + for (int c = 0; c < 128; c++) tmp[c * 16 + 0] = f32_to_bf16(1.0f); + write_smem_canonical_128x16(sB, tmp, 128, 16); + } + __syncthreads(); + + // TMEM alloc (warp 4) + if (wid == 4) { + uint32_t sp = __cvta_generic_to_shared(sTmemBase); + tmem_alloc(sp, TMEM_N); + } + __syncthreads(); + uint32_t tb = *sTmemBase; + + // MMA: A(128,16) × B(128,16) → S(128,128) in TMEM + // S[row, col] should be row if col < 128 (from the identity pattern) + if (wid == 4) { + uint32_t idesc = make_idesc(128, 128); + uint64_t da = make_umma_desc(__cvta_generic_to_shared(sA), 128); + uint64_t db = make_umma_desc(__cvta_generic_to_shared(sB), 128); + if (tid == 128) umma_ss(tb, da, db, idesc, false); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + + // Now have 4 warps read TMEM with 32x32b.x8 + // Each warp reads column group 0 (columns 0-7) + // Lane l reports what value it sees + if (wid < 4) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + 0)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + + // Lane 0 from each warp stores its 8 values + if (lane == 0) { + for (int c = 0; c < 8; c++) { + results[wid * 8 + c] = tmp[c]; + } + } + // Lane 1 from each warp + if (lane == 1) { + for (int c = 0; c < 8; c++) { + results[32 + wid * 8 + c] = tmp[c]; + } + } + // Lane 16 from each warp (to check if warps see different rows) + if (lane == 16) { + for (int c = 0; c < 8; c++) { + results[64 + wid * 8 + c] = tmp[c]; + } + } + } + __syncthreads(); + + if (wid == 4) tmem_dealloc(tb, TMEM_N); +} + +int main() { + printf("TMEM 4-warp read test\n"); + printf("=====================\n"); + printf("After MMA, 4 warps read TMEM with 32x32b.x8.\n"); + printf("If each warp sees different rows, multi-row softmax works.\n\n"); + + float* d_results; + cudaMalloc(&d_results, 128 * sizeof(float)); + cudaMemset(d_results, 0, 128 * sizeof(float)); + + int smem = 256 + 128 + 128*16*2*2 + 512; // sbuf + sA + sB + slack + test_tmem_4warp_read<<<1, 192, smem>>>(d_results); + + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); + printf("Kernel crashed or hung!\n"); + cudaFree(d_results); + return 1; + } + + float h_results[128]; + cudaMemcpy(h_results, d_results, 128 * sizeof(float), cudaMemcpyDeviceToHost); + + printf("Lane 0 from each warp (expecting: warp 0=0.0, warp 1=32.0, warp 2=64.0, warp 3=96.0):\n"); + for (int w = 0; w < 4; w++) { + printf(" Warp %d: [0]=%.1f [1]=%.1f [2]=%.1f [3]=%.1f ... [7]=%.1f\n", + w, h_results[w*8+0], h_results[w*8+1], h_results[w*8+2], h_results[w*8+3], h_results[w*8+7]); + } + + printf("\nLane 1 from each warp (expecting: warp 0=1.0, warp 1=33.0, warp 2=65.0, warp 3=97.0):\n"); + for (int w = 0; w < 4; w++) { + printf(" Warp %d: [0]=%.1f [1]=%.1f ... [7]=%.1f\n", + w, h_results[32+w*8+0], h_results[32+w*8+1], h_results[32+w*8+7]); + } + + printf("\nLane 16 from each warp (expecting: warp 0=16.0, warp 1=48.0, warp 2=80.0, warp 3=112.0):\n"); + for (int w = 0; w < 4; w++) { + printf(" Warp %d: [0]=%.1f [1]=%.1f ... [7]=%.1f\n", + w, h_results[64+w*8+0], h_results[64+w*8+1], h_results[64+w*8+7]); + } + + // Verify: lane 0 of warp 0 should see row 0's value + // If S[row, col] = row (from our identity pattern), then: + // Warp 0, lane 0, col 0 → 0.0 + // Warp 1, lane 0, col 0 → 32.0 (if warps see different rows) + // Warp 0, lane 0, col 0 → 0.0 (if all warps see same rows) + int pass = 1; + if (h_results[0] == 0.0f && h_results[8] == 0.0f && h_results[16] == 0.0f && h_results[24] == 0.0f) { + printf("\nAll warps see the SAME rows (rows 0-31). Multi-warp softmax needs a different approach.\n"); + } else if (h_results[0] == 0.0f && h_results[8] == 32.0f) { + printf("\nEach warp sees DIFFERENT rows! Multi-warp softmax works with 32x32b.x8!\n"); + pass = 1; + } else { + printf("\nUnexpected pattern — need further investigation.\n"); + printf(" Warp0[0]=%.1f Warp1[0]=%.1f Warp2[0]=%.1f Warp3[0]=%.1f\n", + h_results[0], h_results[8], h_results[16], h_results[24]); + } + + cudaFree(d_results); + return 0; +}