/**
 * Minimal tcgen05.mma TS test — P (TMEM) × V (SMEM) → O (TMEM)
 *
 * Test: A (128, 16) in TMEM holds the values 1, 2, ..., 16 along K (identical for every row);
 *       B (16, 16) in SMEM is all 2.0.
 * C (128, 16) is written to TMEM, read back, and compared against the expected value
 * checked at the end of the kernel.
 *
 * This isolates the PV GEMM to debug the "illegal memory access" crash.
 */
|
||
|
||
#include <cuda_runtime.h>
|
||
#include <cstdio>
|
||
#include <cmath>
|
||
#include <cstdlib>
|
||
#include <cstring>
|
||
|
||
#include "dsv4/kernels/attention/fmha_common.cuh"
|
||
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
||
|
||
using namespace dsv4::kernels::attention;
|
||
|
||
// MMA M dimension (rows of A and C) passed to make_idesc() by the test kernel.
constexpr int BLOCK_MN{128};
|
||
|
||
/**
 * Single-CTA smoke test for tcgen05.mma with A in TMEM and B in SMEM ("TS" form).
 *
 * Launch: <<<1, 128, smem>>>. Needs smem >= 16 + 15 + 16*16*sizeof(bf16_t) bytes
 * (main() passes 896). Requires tcgen05 support (sm_100a-class target).
 *
 * Data:
 *   A (128 x 16, bf16, TMEM columns [0, 8)): A[r][k] = k + 1 for every row r.
 *       16-bit A operands are packed two per 32-bit TMEM column (element 2j in the low
 *       half, 2j+1 in the high half), so K = 16 occupies 8 columns.
 *   B (16 x 16, bf16, SMEM, canonical no-swizzle K-major): all 2.0.
 *   C (128 x 16, fp32, TMEM columns [32, 48)): expected C[r][n] = 2 * (1 + ... + 16) = 272.
 *
 * NOTE(review): the A packing and C lane/column layout follow the PTX tcgen05 data-path
 * description for M = 128, cta_group::1 — confirm against make_idesc()/umma_ts_f16().
 */
__global__ void __launch_bounds__(128)
test_mma_ts()
{
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    constexpr int   K_DIM      = 16;     // MMA K (bf16 elements)
    constexpr int   N_DIM      = 16;     // MMA N
    constexpr float B_VAL      = 2.0f;   // every element of B
    constexpr float EXPECTED_C = B_VAL * (K_DIM * (K_DIM + 1) / 2);  // 2 * 136 = 272
    constexpr float TOL        = 1e-3f;  // bf16 small integers and fp32 accumulation are exact

    // ---- Shared memory layout ----
    //   [0, 4)   : TMEM base address written by tcgen05.alloc
    //   [8, 16)  : mbarrier signalled by tcgen05.commit when the MMA completes
    //   [16, ..) : V (16 x 16 bf16), 16-byte aligned as the UMMA descriptor encodes addr >> 4
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    uint64_t* sMmaBar   = (uint64_t*)(sbuf + 8);
    bf16_t*   sV        = (bf16_t*)(((uintptr_t)(sbuf + 16) + 15) & ~(uintptr_t)15);
    const uint32_t bar_addr = (uint32_t)__cvta_generic_to_shared(sMmaBar);

    // Fill V = B_VAL in the (16, 16) canonical layout: 8x8 core matrices (8 rows x 16 B),
    // CORES_MN = 2, CORES_K = 2. The index mapping covers all 256 slots exactly once,
    // so no separate zero-fill is needed.
    for (int i = tid; i < 16 * 16; i += blockDim.x) {
        const int r = i / 16, c = i % 16;
        const int ck = c / 8, lc = c % 8;
        const int tmn = r / 8, lr = r % 8;
        sV[ck * 2 * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(B_VAL);
    }
    // Generic-proxy SMEM writes must be made visible to the async proxy that tcgen05.mma uses.
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");

    // One arrival (from tcgen05.commit) completes phase 0 of the MMA barrier.
    if (tid == 0) {
        asm volatile("mbarrier.init.shared::cta.b64 [%0], 1;" :: "r"(bar_addr) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }

    // TMEM alloc — 64 columns (must be a power of 2, minimum 32).
    // A uses columns [0, 8) (packed bf16), C uses columns [32, 48).
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 64);
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");

    const uint32_t tb   = *sTmemBase;
    const uint32_t tb_a = tb;        // A starts at column 0
    const uint32_t tb_c = tb + 32;   // C starts at column 32
    // A warp may only access TMEM lanes [32*(wid%4), 32*(wid%4)+32); the lane index lives in
    // bits 31:16 of a TMEM address.
    const uint32_t lane_off = (uint32_t)((wid & 3) * 32) << 16;

    // Write A: every warp stores its own 32 rows so all 128 rows read by the M=128 MMA are
    // initialised. Each row is A[r][k] = k + 1, packed as bf16 pairs.
    {
        uint32_t a[K_DIM / 2];
#pragma unroll
        for (int j = 0; j < K_DIM / 2; ++j) {
            // Small integers are exact in bf16, so their bf16 bits are the top half of the fp32 bits.
            const uint32_t lo = __float_as_uint((float)(2 * j + 1)) >> 16;
            const uint32_t hi = __float_as_uint((float)(2 * j + 2)) >> 16;
            a[j] = lo | (hi << 16);
        }
        asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
                     :: "r"(tb_a + lane_off),
                        "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
                        "r"(a[4]), "r"(a[5]), "r"(a[6]), "r"(a[7])
                     : "memory");
        asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
    }

    // Read back row 0 of A (warp 0) and unpack the bf16 pairs to confirm the store.
    if (wid == 0) {
        uint32_t w[8];
        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];"
                     : "=r"(w[0]), "=r"(w[1]), "=r"(w[2]), "=r"(w[3]),
                       "=r"(w[4]), "=r"(w[5]), "=r"(w[6]), "=r"(w[7])
                     : "r"(tb_a));
        asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
        if (lane == 0) {
            printf("A[0,0..15] =");
            for (int j = 0; j < 8; ++j)
                printf(" %.1f %.1f", __uint_as_float(w[j] << 16), __uint_as_float(w[j] & 0xFFFF0000u));
            printf("\n");
        }
    }

    // Order all TMEM stores/loads above before the MMA issued after the barrier.
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();

    // tcgen05.mma TS: A (TMEM, tb_a) x B (SMEM, sV) -> C (TMEM, tb_c), no accumulation.
    const uint64_t dv    = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    const uint32_t idesc = make_idesc(BLOCK_MN, N_DIM);

    if (tid == 0) {
        printf("Before MMA: tb=%u, dv=%llu, idesc=%u, tid=%d\n", tb, (unsigned long long)dv, idesc, tid);
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        umma_ts_f16(tb_c, tb_a, dv, idesc, false);
        // Arrive on the mbarrier once every tcgen05 op previously issued by this thread completes.
        asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                     :: "r"(bar_addr) : "memory");
    }

    // tcgen05.mma is asynchronous: wait for phase 0 of the barrier before touching C.
    {
        uint32_t done = 0;
        while (!done) {
            asm volatile("{\n\t.reg .pred p;\n\t"
                         "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                         "selp.u32 %0, 1, 0, p;\n\t}"
                         : "=r"(done) : "r"(bar_addr), "r"(0u) : "memory");
        }
    }
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");

    if (tid == 0) printf("After MMA: tid=%d\n", tid);

    // Read C: each warp loads its own 32 rows (one row per thread), 16 fp32 columns.
    float c_vals[N_DIM];
#pragma unroll
    for (int n = 0; n < N_DIM / 8; ++n) {
        float* v = &c_vals[n * 8];
        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];"
                     : "=f"(v[0]), "=f"(v[1]), "=f"(v[2]), "=f"(v[3]),
                       "=f"(v[4]), "=f"(v[5]), "=f"(v[6]), "=f"(v[7])
                     : "r"(tb_c + lane_off + n * 8));
    }
    asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");

    // Per-row check. row_ok uses a negated <= so NaN counts as a mismatch (fmaxf drops NaN).
    bool  row_ok  = true;
    float row_err = 0.0f;
    for (int c = 0; c < N_DIM; ++c) {
        const float d = fabsf(c_vals[c] - EXPECTED_C);
        row_ok  = row_ok && (d <= TOL);
        row_err = fmaxf(row_err, d);
    }

    if (tid == 0) {
        printf("C[0,0..7] (row 0, lane 0): ");
        for (int c = 0; c < 8; c++) printf("%.2f ", c_vals[c]);
        printf("\n");
    }

    // Barrier that also counts mismatching rows; the fence orders our TMEM loads before dealloc.
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    const int bad_rows = __syncthreads_count(!row_ok);

    if (tid == 0) {
        printf("Max err (row 0) from %.1f: %.6f\n", EXPECTED_C, row_err);
        printf("Mismatching rows: %d / %d\n", bad_rows, BLOCK_MN);
    }

    if (wid == 0) {
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        tmem_dealloc(tb, 64);
    }
}
|
||
|
||
/**
 * Host driver: launches test_mma_ts on one 128-thread CTA and reports CUDA errors.
 * Returns 1 on any launch or execution error, 0 otherwise. Numerical pass/fail is printed
 * by the kernel itself.
 */
int main() {
    printf("=== Minimal tcgen05.mma TS Test ===\n");

    // Dynamic shared memory: TMEM base slot + alignment slack + 16x16 bf16 V tile + headroom,
    // rounded up to a multiple of 128 bytes (evaluates to 896).
    const int smem = (4 + 16 + 16 * 16 * 2 + 256 + 127) & ~127;
    test_mma_ts<<<1, 128, smem>>>();

    // Launch-configuration errors are reported by cudaGetLastError(); cudaDeviceSynchronize()
    // is not guaranteed to surface them.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) { printf("CUDA LAUNCH ERROR: %s\n", cudaGetErrorString(err)); return 1; }

    // In-kernel faults (e.g. illegal memory access) surface at the next synchronizing call.
    err = cudaDeviceSynchronize();
    if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }

    printf("Kernel completed!\n");
    return 0;
}
|