test: minimal tcgen05.mma TS debug (PV GEMM)
This commit is contained in:
130
tests/unit/test_mma_ts.cu
Normal file
130
tests/unit/test_mma_ts.cu
Normal file
@@ -0,0 +1,130 @@
|
||||
/**
 * Minimal tcgen05.mma TS test — P (TMEM) × V (SMEM) → O (TMEM)
 *
 * Test: A = all 1.0 in TMEM (128, 16), B = all 1.0 in SMEM (16, 16)
 * Expected C = all 16.0 in TMEM (128, 16)
 *
 * This isolates the PV GEMM to debug the "illegal memory access" crash.
 */
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdio>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
#include "dsv4/kernels/attention/fmha_common.cuh"
|
||||
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
||||
|
||||
using namespace dsv4::kernels::attention;
|
||||
|
||||
// M extent handed to make_idesc() for the MMA tile.
constexpr int BLOCK_MN = 128;

/**
 * Debug kernel: one tcgen05.mma "TS" GEMM with A in TMEM and B in shared memory.
 *
 * Launch layout: the strides below assume a single block of 128 threads.
 * Dynamic shared memory: a 4-byte slot for the TMEM base address followed by a
 * 16-byte-aligned 16x16 bf16_t tile for V.
 * Requires a target that supports the tcgen05 / mbarrier PTX instructions.
 *
 * @param o_out  Optional device buffer of at least 16 floats. When non-null,
 *               lane 0 of warp 0 stores the 16 C values it read back from TMEM
 *               (the same values that are printed). May be nullptr.
 */
__global__ void __launch_bounds__(128)
test_mma_ts(float* o_out)
{
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    // mbarrier that tracks completion of the asynchronous tcgen05.mma.
    // Static shared storage so the dynamic SMEM layout below is left untouched.
    __shared__ __align__(8) uint64_t mma_done_bar;
    const uint32_t bar_addr = (uint32_t)__cvta_generic_to_shared(&mma_done_bar);
    if (tid == 0) {
        // One arrival expected: the tcgen05.commit issued after the MMA.
        asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(bar_addr), "r"(1) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }

    // SMEM: tmem_base + V (16, 16) canonical
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    // Skip the 4-byte base slot, then round up to a 16-byte boundary.
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);

    // Load V = all 1.0 into (16, 16) canonical
    // (16, 16): CORES_MN=2, CORES_K=2
    for (int i = tid; i < 16 * 16; i += 128) sV[i] = 0;
    __syncthreads();
    for (int i = tid; i < 16 * 16; i += 128) {
        int r = i / 16, c = i % 16;
        int ck = c / 8, lc = c % 8;     // 8-wide column group and column inside it
        int tmn = r / 8, lr = r % 8;    // 8-tall row group and row inside it
        // Each 8x8 group occupies 64 contiguous elements; column groups are outermost.
        sV[ck * 2 * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(1.0f);
    }
    // sV was written with ordinary shared stores but is consumed by tcgen05.mma
    // through the async proxy; make the writes visible to that proxy.
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    __syncthreads();

    // TMEM alloc — 32 columns (16 for A, 16 for C)
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 32);
    __syncthreads();
    uint32_t tb = *sTmemBase;

    // Write A = all 1.0 into TMEM columns 0-15.
    // NOTE(review): only warp 0 issues the stores, so presumably only TMEM
    // lanes 0-31 (rows 0-31) are initialised, not all 128 rows — confirm
    // against tmem_store. Only row 0 is checked below.
    // NOTE(review): A is stored as FP32 bit patterns; confirm this matches the
    // operand format umma_ts_f16 expects for a TMEM-resident A.
    if (wid == 0) {
        for (int col = 0; col < 16; col++) {
            float v0 = 1.0f, v1 = 1.0f, v2 = 1.0f, v3 = 1.0f;
            tmem_store(tb + col, f32_to_u32(v0), f32_to_u32(v1), f32_to_u32(v2), f32_to_u32(v3));
        }
        tmem_fence_store();
    }
    __syncthreads();

    // Read back A to verify it was written correctly
    if (wid == 0) {
        float check = 0.0f;
        for (int col = 0; col < 16; col++) {
            uint32_t u0, u1, u2, u3;
            tmem_load(tb + col, u0, u1, u2, u3);
            tmem_fence_load();
            check += u32_to_f32(u0);   // only the first returned word per column is summed
        }
        if (lane == 0) printf("A sum (lane 0, col 0, pos 0..3): %.1f (expect 16.0)\n", check);
    }
    __syncthreads();

    // tcgen05.mma TS: A (TMEM) × B (SMEM) → C (TMEM)
    // A is at tb (columns 0-15)
    // B is at sV (16, 16)
    // C goes to tb (same location — will overwrite A)
    // idesc: M=128, N=16 → MMA_M=8, MMA_N=2
    uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    uint32_t idesc = make_idesc(BLOCK_MN, 16);

    // %llu plus an explicit cast keeps the 64-bit descriptor print portable.
    printf("Before MMA: tb=%u, dv=%llu, idesc=%u, tid=%d\n", tb, (unsigned long long)dv, idesc, tid);

    if (tid == 0) {
        umma_ts_f16(tb, tb, dv, idesc, false);
        // tcgen05.mma is asynchronous: have the hardware arrive on the mbarrier
        // once every MMA issued so far by this thread has completed.
        // NOTE(review): assumes umma_ts_f16 issues a cta_group::1 MMA — confirm.
        asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                     :: "r"(bar_addr) : "memory");
    }

    // All threads block until phase 0 of the barrier completes, i.e. the MMA is done.
    {
        uint32_t mma_done = 0;
        while (!mma_done) {
            asm volatile(
                "{\n"
                ".reg .pred p;\n"
                "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n"
                "selp.u32 %0, 1, 0, p;\n"
                "}\n"
                : "=r"(mma_done) : "r"(bar_addr), "r"(0) : "memory");
        }
    }
    // Order the TMEM reads below after the synchronisation with the MMA.
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    printf("After MMA: tid=%d\n", tid);

    // Read C from TMEM
    if (wid == 0) {
        float c_vals[16];
        for (int n = 0; n < 2; n++) {
            // Two x8 loads cover columns 0-7 and 8-15 of this warp's TMEM lanes.
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb + n*8));
            asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
            if (lane == 0) for (int c=0;c<8;c++) c_vals[n*8+c] = tmp[c];
        }
        if (lane == 0) {
            printf("C[0,0..7] (row 0, lane 0): ");
            for (int c=0;c<8;c++) printf("%.2f ", c_vals[c]);
            printf("\n");
            // Expected: all 16.0 (1.0 * 1.0 * 16 = 16.0)
            float max_err = 0.0f;
            for (int c=0;c<16;c++) max_err = fmaxf(max_err, fabsf(c_vals[c] - 16.0f));
            printf("Max err from 16.0: %.6f\n", max_err);
            // Expose row 0 of C to the host as well.
            if (o_out != nullptr) {
                for (int c=0;c<16;c++) o_out[c] = c_vals[c];
            }
        }
    }

    // The allocating warp releases the 32 TMEM columns.
    if (wid == 0) tmem_dealloc(tb, 32);
}
|
||||
|
||||
/**
 * Host driver: launches test_mma_ts on one block of 128 threads and reports
 * any CUDA error from allocation, launch, or execution.
 *
 * Numerical results are printed by the kernel itself.
 *
 * @return 0 if every CUDA call and the kernel finished without error, 1 otherwise.
 */
int main() {
    printf("=== Minimal tcgen05.mma TS Test ===\n");

    float* d_out = nullptr;
    cudaError_t err = cudaMalloc(&d_out, 16 * sizeof(float));
    if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }

    // Dynamic SMEM request, rounded up to a multiple of 128 bytes.
    // NOTE(review): the terms look like base slot (4) + alignment slack (16) +
    // 16x16 two-byte V tile + 256 bytes of padding — confirm against the kernel.
    int smem = (4 + 16 + 16*16*2 + 256 + 127) & ~127;
    test_mma_ts<<<1, 128, smem>>>(d_out);

    // Launch-configuration errors are only visible through cudaGetLastError();
    // faults raised while the kernel runs surface at the synchronize.
    err = cudaGetLastError();
    if (err == cudaSuccess) err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
        cudaFree(d_out);   // best effort; do not leak on the failure path
        return 1;
    }

    printf("Test completed successfully!\n");
    cudaFree(d_out);
    return 0;
}
|
||||
Reference in New Issue
Block a user