// nvfp4-megamoe-kernel/tests/unit/test_mma_ts_copy.cu

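/**
 * Based on the working test_mma_ts.cu: runs the same TS-form PV GEMM
 * (A operand in TMEM), then an SS-form QK GEMM (both operands in SMEM) to
 * check whether a second GEMM still works after the TS MMA.
 * If this crashes, suspect the kernel function signature or the SMEM layout.
 */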
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
constexpr int BLOCK_MN = 128;

// Local PTX wrappers used to wait for tcgen05.mma completion. The prefix keeps
// them from colliding with any same-named helpers pulled in by the
// `using namespace` directive above.

/** Initialise the mbarrier at shared::cta address `bar` to expect `count` arrivals per phase. */
__device__ __forceinline__ void mma_ts_copy_mbar_init(uint32_t bar, uint32_t count)
{
    asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(bar), "r"(count) : "memory");
    // Make the initialisation visible before anyone (incl. tcgen05.commit) uses the barrier.
    asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
}

/** Arrive once on `bar` when every tcgen05.mma previously issued by the calling thread has completed. */
__device__ __forceinline__ void mma_ts_copy_commit(uint32_t bar)
{
    asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                 :: "r"(bar) : "memory");
}

/** Spin until the phase of the mbarrier at `bar` whose parity is `phase` has completed. */
__device__ __forceinline__ void mma_ts_copy_mbar_wait(uint32_t bar, uint32_t phase)
{
    uint32_t done = 0;
    do {
        asm volatile("{\n\t"
                     ".reg .pred p;\n\t"
                     "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                     "selp.u32 %0, 1, 0, p;\n\t"
                     "}"
                     : "=r"(done) : "r"(bar), "r"(phase) : "memory");
    } while (!done);
}

/**
 * Single-CTA smoke test: a TS-form PV MMA (A in TMEM, B in SMEM) followed by
 * an SS-form QK MMA, printing one row of each result.
 *
 * Launch: <<<1, 128, smem>>>. `smem` must cover the dynamic layout below
 * (about 16.9 KB); the kernel traps if the launch provides less.
 * Uses tcgen05.* instructions, so it needs a Blackwell (sm_100a-class) target.
 *
 * Dynamic SMEM: [u32 TMEM base][pad to 16 B][V 16x16][Q 128x16][4096-element gap][K 128x16], bf16.
 * TMEM columns: A = [0,16), PV accumulator = [32,48), QK accumulator = [128,256).
 */
__global__ void __launch_bounds__(128)
test_mma_ts_copy()
{
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    // tcgen05.alloc needs a power of two >= 32 columns; 256 fits all three regions.
    constexpr int kTmemCols = 256;
    constexpr uint32_t kTmemColA = 0, kTmemColO = 32, kTmemColS = 128;

    extern __shared__ char sbuf[];
    __shared__ uint64_t mma_bar;  // static SMEM, so it does not shift the dynamic layout

    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sQ = sV + 16 * 16;
    bf16_t* sK = sQ + 128 * 16 + 4096;

    // Refuse to run (uniformly across the block) if the launch gave too little dynamic SMEM.
    uint32_t dyn_smem;
    asm volatile("mov.u32 %0, %%dynamic_smem_size;" : "=r"(dyn_smem));
    const uint32_t smem_needed = (uint32_t)((const char*)(sK + 128 * 16) - sbuf);
    if (smem_needed > dyn_smem) {
        if (tid == 0) {
            printf("test_mma_ts_copy: needs %u bytes of dynamic SMEM, launched with %u\n",
                   smem_needed, dyn_smem);
            __trap();
        }
        return;
    }

    const uint32_t bar = (uint32_t)__cvta_generic_to_shared(&mma_bar);
    uint32_t bar_phase = 0;

    // ---- V = 2.0 (16x16), K-major no-swizzle core-matrix layout ----
    for (int i = tid; i < 16 * 16; i += 128) sV[i] = 0;
    __syncthreads();
    for (int i = tid; i < 16 * 16; i += 128) {
        int r = i / 16, c = i % 16;
        int ck = c / 8, lc = c % 8;
        int tmn = r / 8, lr = r % 8;
        sV[ck * 2 * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(2.0f);
    }
    if (tid == 0) mma_ts_copy_mbar_init(bar, 1);
    // Generic-proxy SMEM writes must be fenced before tcgen05.mma reads them via the async proxy.
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    __syncthreads();

    // ---- TMEM allocation (warp 0), then publish the base address to all threads ----
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), kTmemCols);
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    const uint32_t tb = *sTmemBase;
    const uint32_t tA = tb + kTmemColA;
    const uint32_t tO = tb + kTmemColO;
    const uint32_t tS = tb + kTmemColS;

    // PV GEMM FIRST (before QK), then QK.
    // Write A into TMEM columns 0-15. Only warp 0 stores, so only TMEM lanes
    // 0-31 (rows 0-31) of A are defined; row 0 is all this test checks.
    if (wid == 0) {
        for (int n = 0; n < 16 / 8; n++) {
            float p0 = 1.0f, p1 = 1.0f, p2 = 1.0f, p3 = 1.0f;
            float p4 = 1.0f, p5 = 1.0f, p6 = 1.0f, p7 = 1.0f;
            asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};"
                         :: "r"(tA + n * 8), "f"(p0), "f"(p1), "f"(p2), "f"(p3),
                            "f"(p4), "f"(p5), "f"(p6), "f"(p7));
        }
        tmem_fence_store();
        // Ensure the TMEM stores have landed before the MMA consumes them.
        asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
    }
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();

    // ---- MMA TS: O = A(TMEM) x V^T, issued by a single thread ----
    uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    uint32_t idesc = make_idesc(BLOCK_MN, 16);
    if (tid == 0) {
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        umma_ts_f16(tO, tA, dv, idesc, false);
        mma_ts_copy_commit(bar);  // arrive on `bar` once the MMA has completed
    }
    // tcgen05.mma is asynchronous: wait for its completion before reading TMEM.
    mma_ts_copy_mbar_wait(bar, bar_phase);
    bar_phase ^= 1;
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");

    // ---- Read row 0 of O (warp 0 covers TMEM lanes 0-31) ----
    // NOTE(review): A was stored as fp32 1.0f bit patterns. If kind::f16 reads each
    // 32-bit column as two packed bf16 values, that is (0.0, 1.0) per column, giving
    // 8 * 1.0 * 2.0 = 16, which matches the printed expectation. TODO confirm.
    if (wid == 0) {
        float c_vals[16];
        for (int n = 0; n < 2; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                         : "=f"(tmp[0]), "=f"(tmp[1]), "=f"(tmp[2]), "=f"(tmp[3]),
                           "=f"(tmp[4]), "=f"(tmp[5]), "=f"(tmp[6]), "=f"(tmp[7])
                         : "r"(tO + n * 8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            if (lane == 0) for (int c = 0; c < 8; c++) c_vals[n * 8 + c] = tmp[c];
        }
        if (lane == 0) {
            printf("C[0,0..7] after PV: ");
            for (int c = 0; c < 8; c++) printf("%.2f ", c_vals[c]);
            printf("(expect 16.0)\n");
        }
    }
    __syncthreads();

    // ---- NOW do QK GEMM: does it still work after the PV TS MMA? ----
    for (int i = tid; i < 128 * 16; i += 128) { sQ[i] = 0; sK[i] = 0; }
    __syncthreads();
    // Q: only row 0 is non-zero (1.0 in all 16 K positions).
    for (int d = tid; d < 16; d += 128) { int ck = d / 8, lc = d % 8; sQ[ck * 16 * 64 + lc] = f32_to_bf16(1.0f); }
    // K: every element 1.0.
    for (int i = tid; i < 128 * 16; i += 128) {
        int r = i / 16, c = i % 16; int ck = c / 8, lc = c % 8, tmn = r / 8, lr = r % 8;
        sK[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(1.0f);
    }
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    __syncthreads();

    uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ), 128);
    uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK), 128);
    uint32_t iqk = make_idesc(128, 128);
    // QK (N = 128) writes TMEM columns 128-255, disjoint from A (0-15) and O (32-47).
    // Exactly one thread issues the MMA (lane == 0 would issue it once per warp).
    if (tid == 0) {
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        umma_ss_f16(tS, dq, dk, iqk, false);
        mma_ts_copy_commit(bar);
    }
    mma_ts_copy_mbar_wait(bar, bar_phase);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");

    // Read S[0,0..7] from the QK accumulator (not from A at column 0).
    if (wid == 0) {
        float tmp[8];
        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                     : "=f"(tmp[0]), "=f"(tmp[1]), "=f"(tmp[2]), "=f"(tmp[3]),
                       "=f"(tmp[4]), "=f"(tmp[5]), "=f"(tmp[6]), "=f"(tmp[7])
                     : "r"(tS));
        asm volatile("tcgen05.wait::ld.sync.aligned;");
        if (lane == 0) printf("S[0,0] after QK (post-PV): %.2f (expect 16.0)\n", tmp[0]);
    }

    // All TMEM traffic is complete; release the allocation.
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    if (wid == 0) tmem_dealloc(tb, kTmemCols);
}
/**
 * Launch the single-CTA MMA TS copy test and report CUDA errors.
 * Returns 0 on success, 1 if the launch or kernel execution failed.
 */
int main() {
    printf("=== MMA TS Copy Test ===\n");

    // Dynamic SMEM must cover the kernel's layout:
    //   4 B TMEM-base slot + up to 15 B padding so V is 16-byte aligned,
    //   then V (16x16) + Q (128x16) + a 4096-element gap + K (128x16), all 2-byte bf16.
    constexpr size_t kBf16Bytes = 2;
    constexpr size_t kElems = 16 * 16 + 128 * 16 + 4096 + 128 * 16;
    constexpr size_t kBytes = 4 + 15 + kElems * kBf16Bytes;
    const int smem = (int)((kBytes + 127) & ~(size_t)127);

    test_mma_ts_copy<<<1, 128, smem>>>();
    // Separate launch-configuration errors from asynchronous execution errors.
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess) err = cudaDeviceSynchronize();
    if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
    printf("PASS\n");
    return 0;
}