diff --git a/tests/unit/test_tmem_4warp_read.cu b/tests/unit/test_tmem_4warp_read.cu index e32ee1d4..66afc58e 100644 --- a/tests/unit/test_tmem_4warp_read.cu +++ b/tests/unit/test_tmem_4warp_read.cu @@ -1,14 +1,10 @@ /** - * Test: Can 4 warps read TMEM data written by UMMA (not 32x32b.x8 store)? + * Test: Can we do 16x256b.x1 LOADS multiple times without crashing? + * (The crash was on 16x256b.x1 STORES, not loads.) * - * The MMA writes to TMEM via the tcgen05.mma instruction, which uses - * the full 128-thread hardware pipeline. This should produce TMEM data - * that's visible to all warps in the CTA. - * - * If this works, then multi-row softmax is possible: - * - QK MMA writes S to TMEM (hardware pipeline) - * - 4 warps read S with 32x32b.x8 (each sees 32 rows) - * - The question is: does each warp see DIFFERENT 32 rows, or the same 32? + * If loads work multiple times, we can: + * - Use 16x256b.x1 for softmax reads (128 rows, 1 column per call) + * - Use 32x32b.x8 for everything else (stores, PV, epilogue) */ #include @@ -21,9 +17,6 @@ using bf16_t = unsigned short; __device__ __forceinline__ bf16_t f32_to_bf16(float f) { bf16_t h; asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h; } -__device__ __forceinline__ float bf16_to_f32(bf16_t h) { - float f; asm("cvt.f32.bf16 %0, %1;" : "=f"(f) : "h"(h)); return f; -} __device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) { asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" @@ -34,131 +27,125 @@ __device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) { :: "r"(tmem_ptr), "r"(num_cols)); } -__device__ __forceinline__ uint64_t desc_encode(uint64_t byte_val) { return byte_val >> 4; } -__device__ __forceinline__ uint64_t make_umma_desc(uint32_t smem_addr, int block_mn) { - uint64_t desc = 0; - desc |= desc_encode(smem_addr) & 0x3FFF; - desc |= (desc_encode((uint64_t)block_mn * 16) & 0x3FFF) << 16; - desc |= (desc_encode(128ULL) & 0x3FFF) << 32; - desc |= 1ULL << 46; - return desc; -} -__device__ __forceinline__ uint32_t make_idesc(int block_m, int block_n) { - return (1U << 4) | (1U << 7) | (1U << 10) | - ((uint32_t)(block_n >> 3) << 17) | ((uint32_t)(block_m >> 4) << 24); -} -__device__ void umma_ss(uint32_t tmem_c, uint64_t da, uint64_t db, uint32_t idesc, bool acc) { - uint32_t sc = acc ? 0x3F800000u : 0u; - asm volatile("{\n\t.reg .pred p;\n\tsetp.ne.b32 p, %4, 0;\n\t" - "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5,%6,%7,%8}, p;\n\t}" - :: "r"(tmem_c), "l"(da), "l"(db), "r"(idesc), "r"(sc), "r"(0), "r"(0), "r"(0), "r"(0)); -} - -// Write (rows, 16) BF16 matrix to SMEM canonical layout -__device__ void write_canonical(bf16_t* dst, const bf16_t* src, int rows, int cols16) { - constexpr int CORES_MN = 16; - for (int i = threadIdx.x; i < 128 * cols16; i += 192) { - int r = i / cols16, c = i % cols16; - if (r < rows && c < cols16) { - int core_mn = r / 8, local_r = r % 8; - int core_k = c / 8, local_c = c % 8; - int dst_idx = core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c; - dst[dst_idx] = src[r * cols16 + c]; - } - } -} - -__global__ void __launch_bounds__(192) -test_mma_4warp_read(float* results) { - const int tid = threadIdx.x; - const int wid = tid / 32; - const int lane = tid % 32; +__global__ void __launch_bounds__(32) +test_16x256b_loads(float* results) { + const int lane = threadIdx.x; const int TMEM_N = 128; extern __shared__ char sbuf[]; uint32_t* sTmemBase = (uint32_t*)sbuf; - bf16_t* sA = (bf16_t*)(((uintptr_t)(sbuf + 256) + 127) & ~(uintptr_t)127); - bf16_t* sB = sA + 128 * 16; - // TMEM alloc (warp 4) - if (wid == 4) { - uint32_t sp = __cvta_generic_to_shared(sTmemBase); - tmem_alloc(sp, TMEM_N); - } - __syncthreads(); + uint32_t sp = __cvta_generic_to_shared(sTmemBase); + tmem_alloc(sp, TMEM_N); + __syncwarp(); uint32_t tb = *sTmemBase; - // Build A such that A[row, 0] = row+1 (row 0 = 1.0, row 1 = 2.0, etc.) - // This way we can identify which row each lane is reading - if (wid == 5) { // load warp - bf16_t tmp[128 * 16]; - for (int i = 0; i < 128 * 16; i++) tmp[i] = f32_to_bf16(0.0f); - for (int r = 0; r < 128; r++) tmp[r * 16 + 0] = f32_to_bf16((float)(r + 1)); - write_canonical(sA, tmp, 128, 16); + // Write data via 32x32b.x8 (known working for stores) + { + float vals[8]; + for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c); + uint32_t ivals[8]; + for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4); - // B: B[col, 0] = 1.0 for col < 128 - for (int i = 0; i < 128 * 16; i++) tmp[i] = f32_to_bf16(0.0f); - for (int c = 0; c < 128; c++) tmp[c * 16 + 0] = f32_to_bf16(1.0f); - write_canonical(sB, tmp, 128, 16); + // Write to column groups 0-3 (32 columns) + for (int n = 0; n < 4; n++) { + asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" + :: "r"(tb + n * 8), + "r"(ivals[0]), "r"(ivals[1]), "r"(ivals[2]), "r"(ivals[3]), + "r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7])); + asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory"); + // Update vals for next column group + for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + n * 8 + c); + for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4); + } } - __syncthreads(); - // MMA: A(128,16) × B(128,16) → S(128,128) in TMEM - // S[row, col] should be (row+1) for col < 128 - if (wid == 4) { - uint32_t idesc = make_idesc(128, 128); - uint64_t da = make_umma_desc(__cvta_generic_to_shared(sA), 128); - uint64_t db = make_umma_desc(__cvta_generic_to_shared(sB), 128); - if (tid == 128) umma_ss(tb, da, db, idesc, false); - asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); - } - __syncthreads(); - // Extra fence for TMEM visibility - asm volatile("fence.sc.gpu;" ::: "memory"); - __syncthreads(); + // Now try reading with 16x256b.x1 loads + // 16x256b.x1: lane l reads 4 FP32 values (rows l*4+0..3) from 1 column + int load_count = 0; + int pass = 1; - // 4 warps read TMEM with 32x32b.x8 - // Read column group 0 (columns 0-7) - if (wid < 4) { - float tmp[8]; - asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" - : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), - "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) - : "r"(tb + 0)); + // Read column 0 — lane 0 should get rows 0-3, lane 1 should get rows 4-7, etc. + { + float v0, v1, v2, v3; + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4, %5];" + : "=f"(v0), "=f"(v1), "=f"(v2), "=f"(v3) + : "r"(tb), "r"(0)); // column 0 asm volatile("tcgen05.wait::ld.sync.aligned;"); + load_count++; - // Lane 0: store col 0 value (should be row_id+1) - if (lane == 0) results[wid] = tmp[0]; - // Lane 1 - if (lane == 1) results[4 + wid] = tmp[0]; - // Lane 15 - if (lane == 15) results[8 + wid] = tmp[0]; - // Lane 31 - if (lane == 31) results[12 + wid] = tmp[0]; - - // Lane 0, also store col 1 value - if (lane == 0) results[16 + wid] = tmp[1]; + // Lane l should see values for rows l*4+0..3 + // From the store: row 0 col 0 = 0.0, row 1 col 0 = 10.0, row 2 col 0 = 20.0, ... + // Wait — the 32x32b.x8 store wrote lane l's data to "row l" in TMEM. + // So row l, col 0 = l*10 + 0 = l*10 + // Lane 0 reads rows 0-3: v0=row0=0, v1=row1=10, v2=row2=20, v3=row3=30 + if (lane == 0) { + results[0] = v0; + results[1] = v1; + results[2] = v2; + results[3] = v3; + } + if (lane == 1) { + results[4] = v0; + results[5] = v1; + results[6] = v2; + results[7] = v3; + } } - __syncthreads(); - if (wid == 4) tmem_dealloc(tb, TMEM_N); + // Read column 1 (2nd 16x256b.x1 load — does it crash?) + { + float v0, v1, v2, v3; + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4, %5];" + : "=f"(v0), "=f"(v1), "=f"(v2), "=f"(v3) + : "r"(tb), "r"(1)); // column 1 + asm volatile("tcgen05.wait::ld.sync.aligned;"); + load_count++; + + if (lane == 0) { + results[8] = v0; + results[9] = v1; + results[10] = v2; + results[11] = v3; + } + } + + // Read column 8 (8th column — more 16x256b.x1 loads) + { + float v0, v1, v2, v3; + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4, %5];" + : "=f"(v0), "=f"(v1), "=f"(v2), "=f"(v3) + : "r"(tb), "r"(8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + load_count++; + + if (lane == 0) { + results[12] = v0; // row 0, col 8 = 0*10+8 = 8.0 + results[13] = v1; // row 1, col 8 = 10+8 = 18.0? No, 1*10+0=10 + 8 = 18 + results[14] = v2; + results[15] = v3; + } + } + + // Store load count (if we get here, loads didn't crash) + if (lane == 0) results[16] = (float)load_count; + + tmem_dealloc(tb, TMEM_N); } int main() { - printf("MMA → 4-warp TMEM read test\n"); - printf("============================\n"); - printf("After UMMA, S[row, col] = row+1 for col < 128.\n"); - printf("If 4 warps see different rows, lane l in warp w should see row w*32+l.\n\n"); + printf("16x256b.x1 multiple LOAD test\n"); + printf("==============================\n\n"); float* d_r; cudaMalloc(&d_r, 32 * sizeof(float)); cudaMemset(d_r, 0, 32 * sizeof(float)); - int smem = 256 + 128 + 128*16*2*2 + 256; // sbuf + align + sA + sB + slack - test_mma_4warp_read<<<1, 192, smem>>>(d_r); + test_16x256b_loads<<<1, 32, 256>>>(d_r); cudaError_t err = cudaDeviceSynchronize(); if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); + printf("16x256b.x1 loads CRASHED (likely on 2nd or 3rd call)\n"); cudaFree(d_r); return 1; } @@ -166,22 +153,22 @@ int main() { float h[32]; cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost); - printf("Lane 0, col 0 (expect row_id+1): "); - printf("w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[0], h[1], h[2], h[3]); - printf(" If w0=1.0 w1=33.0 w2=65.0 w3=97.0 → each warp sees different 32 rows ✓\n"); - printf(" If all same → all warps see same 32 rows\n\n"); + printf("Load count: %d (3 loads completed = no crash)\n\n", (int)h[16]); - printf("Lane 1, col 0: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[4], h[5], h[6], h[7]); - printf("Lane 15, col 0: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[8], h[9], h[10], h[11]); - printf("Lane 31, col 0: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[12], h[13], h[14], h[15]); - printf("Lane 0, col 1: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[16], h[17], h[18], h[19]); + printf("Column 0, lane 0 (expect rows 0-3 = 0,10,20,30): %.1f %.1f %.1f %.1f\n", + h[0], h[1], h[2], h[3]); + printf("Column 0, lane 1 (expect rows 4-7 = 40,50,60,70): %.1f %.1f %.1f %.1f\n", + h[4], h[5], h[6], h[7]); + printf("Column 1, lane 0 (expect rows 0-3 = 1,11,21,31): %.1f %.1f %.1f %.1f\n", + h[8], h[9], h[10], h[11]); + printf("Column 8, lane 0 (expect row 0 col 8 = 8, row 1 col 8 = 18, etc): %.1f %.1f %.1f %.1f\n", + h[12], h[13], h[14], h[15]); - // Check if warps see different rows - int diff = (h[0] != h[1]) || (h[1] != h[2]) || (h[2] != h[3]); - printf("\nConclusion: %s\n", diff ? - "Warps see DIFFERENT rows — multi-warp softmax works!" : - "Warps see SAME rows — need alternative approach"); + // Verify: column 0, lane 0 should give [0, 10, 20, 30] + int pass = (fabsf(h[0] - 0.0f) < 0.01f) && (fabsf(h[1] - 10.0f) < 0.01f) && + (fabsf(h[2] - 20.0f) < 0.01f) && (fabsf(h[3] - 30.0f) < 0.01f); + printf("\nResult: %s\n", pass ? "16x256b.x1 LOADS work multiple times!" : "Data mismatch"); cudaFree(d_r); - return 0; + return pass ? 0 : 1; }