test: minimal tcgen05.mma TS debug (PV GEMM)

This commit is contained in:
2026-05-28 13:31:18 +00:00
parent efa03f53d4
commit 37a502e476

130
tests/unit/test_mma_ts.cu Normal file
View File

@@ -0,0 +1,130 @@
/**
* Minimal tcgen05.mma TS test — P (TMEM) × V (SMEM) → O (TMEM)
*
* Test: A = all 1.0 in TMEM (128, 16), B = all 1.0 in SMEM (16, 16)
* Expected C = all 16.0 in TMEM (128, 16)
*
* This isolates the PV GEMM to debug the "illegal memory access" crash.
*/
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
constexpr int BLOCK_MN = 128;
__global__ void __launch_bounds__(128)
test_mma_ts(float* o_out)
{
const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;
// SMEM: tmem_base + V (16, 16) canonical
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
bf16_t* sV = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
// Load V = all 1.0 into (16, 16) canonical
// (16, 16): CORES_MN=2, CORES_K=2
for (int i = tid; i < 16 * 16; i += 128) sV[i] = 0;
__syncthreads();
for (int i = tid; i < 16 * 16; i += 128) {
int r = i / 16, c = i % 16;
int ck = c / 8, lc = c % 8;
int tmn = r / 8, lr = r % 8;
sV[ck * 2 * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(1.0f);
}
__syncthreads();
// TMEM alloc — 32 columns (16 for A, 16 for C)
if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 32);
__syncthreads();
uint32_t tb = *sTmemBase;
// Write A = all 1.0 into TMEM columns 0-15 (128 rows × 16 columns)
if (wid == 0) {
for (int col = 0; col < 16; col++) {
// Each column: 128 FP32. Lane i writes positions i*4..i*4+3
float v0 = 1.0f, v1 = 1.0f, v2 = 1.0f, v3 = 1.0f;
tmem_store(tb + col, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3));
}
tmem_fence_store();
}
__syncthreads();
// Read back A to verify it was written correctly
if (wid == 0) {
float check = 0.0f;
for (int col = 0; col < 16; col++) {
uint32_t u0, u1, u2, u3;
tmem_load(tb + col, u0, u1, u2, u3);
tmem_fence_load();
check += u32_to_f32(u0);
}
if (lane == 0) printf("A sum (lane 0, col 0, pos 0..3): %.1f (expect 16.0)\n", check);
}
__syncthreads();
// tcgen05.mma TS: A (TMEM) × B (SMEM) → C (TMEM)
// A is at tb (columns 0-15)
// B is at sV (16, 16)
// C goes to tb (same location — will overwrite A)
// idesc: M=128, N=16 → MMA_M=8, MMA_N=2
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
uint32_t idesc = make_idesc(BLOCK_MN, 16);
printf("Before MMA: tb=%u, dv=%lu, idesc=%u, tid=%d\n", tb, dv, idesc, tid);
if (tid == 0) {
umma_ts_f16(tb, tb, dv, idesc, false);
}
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
printf("After MMA: tid=%d\n", tid);
// Read C from TMEM
if (wid == 0) {
float c_vals[16];
for (int n = 0; n < 2; n++) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb + n*8));
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) for (int c=0;c<8;c++) c_vals[n*8+c] = tmp[c];
}
if (lane == 0) {
printf("C[0,0..7] (row 0, lane 0): ");
for (int c=0;c<8;c++) printf("%.2f ", c_vals[c]);
printf("\n");
// Expected: all 16.0 (1.0 * 1.0 * 16 = 16.0)
float max_err = 0.0f;
for (int c=0;c<16;c++) max_err = fmaxf(max_err, fabsf(c_vals[c] - 16.0f));
printf("Max err from 16.0: %.6f\n", max_err);
}
}
if (wid == 0) tmem_dealloc(tb, 32);
}
int main() {
printf("=== Minimal tcgen05.mma TS Test ===\n");
float* d_out;
cudaMalloc(&d_out, 16 * sizeof(float));
int smem = (4 + 16 + 16*16*2 + 256 + 127) & ~127;
test_mma_ts<<<1, 128, smem>>>(d_out);
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
printf("Test completed successfully!\n");
cudaFree(d_out);
return 0;
}