diff --git a/tests/unit/test_mma_ts.cu b/tests/unit/test_mma_ts.cu new file mode 100644 index 00000000..daf1ef4c --- /dev/null +++ b/tests/unit/test_mma_ts.cu @@ -0,0 +1,130 @@ +/** + * Minimal tcgen05.mma TS test — P (TMEM) × V (SMEM) → O (TMEM) + * + * Test: A = all 1.0 in TMEM (128, 16), B = all 1.0 in SMEM (16, 16) + * Expected C = all 16.0 in TMEM (128, 16) + * + * This isolates the PV GEMM to debug the "illegal memory access" crash. + */ + +#include +#include +#include +#include +#include + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" + +using namespace dsv4::kernels::attention; + +constexpr int BLOCK_MN = 128; + +__global__ void __launch_bounds__(128) +test_mma_ts(float* o_out) +{ + const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32; + + // SMEM: tmem_base + V (16, 16) canonical + extern __shared__ char sbuf[]; + uint32_t* sTmemBase = (uint32_t*)sbuf; + bf16_t* sV = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); + + // Load V = all 1.0 into (16, 16) canonical + // (16, 16): CORES_MN=2, CORES_K=2 + for (int i = tid; i < 16 * 16; i += 128) sV[i] = 0; + __syncthreads(); + for (int i = tid; i < 16 * 16; i += 128) { + int r = i / 16, c = i % 16; + int ck = c / 8, lc = c % 8; + int tmn = r / 8, lr = r % 8; + sV[ck * 2 * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(1.0f); + } + __syncthreads(); + + // TMEM alloc — 32 columns (16 for A, 16 for C) + if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 32); + __syncthreads(); + uint32_t tb = *sTmemBase; + + // Write A = all 1.0 into TMEM columns 0-15 (128 rows × 16 columns) + if (wid == 0) { + for (int col = 0; col < 16; col++) { + // Each column: 128 FP32. Lane i writes positions i*4..i*4+3 + float v0 = 1.0f, v1 = 1.0f, v2 = 1.0f, v3 = 1.0f; + tmem_store(tb + col, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3)); + } + tmem_fence_store(); + } + __syncthreads(); + + // Read back A to verify it was written correctly + if (wid == 0) { + float check = 0.0f; + for (int col = 0; col < 16; col++) { + uint32_t u0, u1, u2, u3; + tmem_load(tb + col, u0, u1, u2, u3); + tmem_fence_load(); + check += u32_to_f32(u0); + } + if (lane == 0) printf("A sum (lane 0, col 0, pos 0..3): %.1f (expect 16.0)\n", check); + } + __syncthreads(); + + // tcgen05.mma TS: A (TMEM) × B (SMEM) → C (TMEM) + // A is at tb (columns 0-15) + // B is at sV (16, 16) + // C goes to tb (same location — will overwrite A) + // idesc: M=128, N=16 → MMA_M=8, MMA_N=2 + uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16); + uint32_t idesc = make_idesc(BLOCK_MN, 16); + + printf("Before MMA: tb=%u, dv=%lu, idesc=%u, tid=%d\n", tb, dv, idesc, tid); + + if (tid == 0) { + umma_ts_f16(tb, tb, dv, idesc, false); + } + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + + printf("After MMA: tid=%d\n", tid); + + // Read C from TMEM + if (wid == 0) { + float c_vals[16]; + for (int n = 0; n < 2; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) c_vals[n*8+c] = tmp[c]; + } + if (lane == 0) { + printf("C[0,0..7] (row 0, lane 0): "); + for (int c=0;c<8;c++) printf("%.2f ", c_vals[c]); + printf("\n"); + // Expected: all 16.0 (1.0 * 1.0 * 16 = 16.0) + float max_err = 0.0f; + for (int c=0;c<16;c++) max_err = fmaxf(max_err, fabsf(c_vals[c] - 16.0f)); + printf("Max err from 16.0: %.6f\n", max_err); + } + } + + if (wid == 0) tmem_dealloc(tb, 32); +} + +int main() { + printf("=== Minimal tcgen05.mma TS Test ===\n"); + + float* d_out; + cudaMalloc(&d_out, 16 * sizeof(float)); + + int smem = (4 + 16 + 16*16*2 + 256 + 127) & ~127; + test_mma_ts<<<1, 128, smem>>>(d_out); + + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } + + printf("Test completed successfully!\n"); + cudaFree(d_out); + return 0; +}