Files
nvfp4-megamoe-kernel/tests/unit/test_mma_ts.cu

133 lines
5.0 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
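/**
 * Minimal tcgen05.mma TS test - P (TMEM) x V (SMEM) -> O (TMEM)
 *
 * Test: A in TMEM (128 x 16) holds 1.0, 2.0, ..., 16.0 along K (the same in
 * every row that is written); B in SMEM (16 x 16) is all 2.0. Only warp 0
 * writes A, so only TMEM lanes 0-31 (rows 0-31) are initialized, and only
 * those rows of C (TMEM columns 32-47) are read back.
 * A plain product gives C = 2 * (1 + ... + 16) = 272.0 per element; the kernel
 * compares row 0 of C against the value 136.0 that it hard-codes.
 *
 * This isolates the PV GEMM to debug the "illegal memory access" crash.
 */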
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
// M extent of the tested MMA tile; passed as the M argument of make_idesc.
constexpr int BLOCK_MN{128};
/**
 * Single-CTA tcgen05.mma "TS" smoke test: A (TMEM) x B (SMEM) -> C (TMEM).
 *
 * Launch contract (see main): one block of exactly 128 threads. Dynamic SMEM
 * must hold a 4-byte TMEM-address slot, a 16-byte-aligned 16x16 bf16 B tile
 * (512 B) and an 8-byte mbarrier placed right after the tile.
 * tcgen05 instructions require an sm_100a-class target.
 *
 * TMEM layout (64 columns allocated):
 *   columns [0, 16)  : A, written by warp 0 only (TMEM lanes 0-31)
 *   columns [32, 48) : C, written by the MMA
 *
 * Synchronization:
 *   - Generic-proxy SMEM stores are fenced to the async proxy before the MMA
 *     reads them.
 *   - tcgen05 fences bracket every CTA barrier that orders TMEM accesses
 *     across threads.
 *   - C is read (and TMEM freed) only after tcgen05.commit has signalled the
 *     MMA's completion on an mbarrier.
 */
__global__ void __launch_bounds__(128)
test_mma_ts()
{
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    // ---- Shared memory carve-up ----
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;  // tcgen05.alloc writes the TMEM base here
    // B tile, rounded up to 16 B (UMMA SMEM descriptors encode addresses in 16-B units).
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    // MMA-completion mbarrier right after the 256-element tile (stays 8-B aligned).
    uint64_t* sMmaBar = (uint64_t*)(sV + 16 * 16);
    const uint32_t mma_bar = (uint32_t)__cvta_generic_to_shared(sMmaBar);

    // ---- B = 2.0 everywhere, stored as 8x8 core matrices ----
    // (16, 16): CORES_MN=2, CORES_K=2; core (tmn, ck) starts at element ck*128 + tmn*64.
    // The index map is a bijection onto [0, 256), so no pre-zeroing pass is needed.
    for (int i = tid; i < 16 * 16; i += 128) {
        const int r = i / 16, c = i % 16;
        const int ck = c / 8, lc = c % 8;
        const int tmn = r / 8, lr = r % 8;
        sV[ck * 2 * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(2.0f);  // 2.0 to distinguish from A
    }
    if (tid == 0) {
        // Arrival count 1: the single tcgen05.commit issued after the MMA.
        asm volatile("mbarrier.init.shared::cta.b64 [%0], 1;" :: "r"(mma_bar) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    // tcgen05.mma reads sV through the async proxy. Make the generic-proxy
    // stores above visible to it before the MMA is issued.
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    __syncthreads();

    // ---- TMEM alloc: 64 columns (power of 2, minimum 32) ----
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 64);
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    const uint32_t tb = *sTmemBase;
    const uint32_t tb_a = tb;       // A: columns 0-15
    const uint32_t tb_c = tb + 32;  // C: columns 32-47 (does not overlap A)

    // ---- A[row, k] = k + 1 (k = 0..15), identical for every row warp 0 writes ----
    // 32x32b.x8: each lane stores 8 consecutive 32-bit columns of its own TMEM lane.
    if (wid == 0) {
        for (int n = 0; n < 16 / 8; n++) {
            const float b = 8.0f * n;
            asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};"
                         :: "r"(tb_a + n * 8),
                            "f"(b + 1.0f), "f"(b + 2.0f), "f"(b + 3.0f), "f"(b + 4.0f),
                            "f"(b + 5.0f), "f"(b + 6.0f), "f"(b + 7.0f), "f"(b + 8.0f));
        }
        // Make the stores complete before TMEM is read by the MMA or by
        // tcgen05.ld (redundant, but harmless, if tmem_fence_store() already waits).
        asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
        tmem_fence_store();
    }
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");

    // ---- Read back A (row 0 printed) ----
    if (wid == 0) {
        for (int n = 0; n < 16 / 8; n++) {
            float tmp[8];
            // ld + wait in one statement so the outputs are only consumed after the wait.
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];\n\t"
                         "tcgen05.wait::ld.sync.aligned;"
                         : "=f"(tmp[0]), "=f"(tmp[1]), "=f"(tmp[2]), "=f"(tmp[3]),
                           "=f"(tmp[4]), "=f"(tmp[5]), "=f"(tmp[6]), "=f"(tmp[7])
                         : "r"(tb_a + n * 8)
                         : "memory");
            if (lane == 0) printf("A[0,%d..%d] = %.1f %.1f %.1f %.1f\n", n*8, n*8+7, tmp[0], tmp[1], tmp[2], tmp[3]);
        }
    }
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");

    // ---- C = A x B via tcgen05.mma (TS: A from TMEM, B from SMEM) ----
    // idesc: M=128, N=16 -> MMA_M=8, MMA_N=2. C goes to columns 32-47, leaving A intact.
    const uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    const uint32_t idesc = make_idesc(BLOCK_MN, 16);
    // %llu + cast: %lu is not a portable specifier for uint64_t.
    printf("Before MMA: tb=%u, dv=%llu, idesc=%u, tid=%d\n", tb, (unsigned long long)dv, idesc, tid);
    if (tid == 0) {
        umma_ts_f16(tb_c, tb_a, dv, idesc, false);
        // tcgen05.mma is asynchronous; have its completion arrive on the mbarrier.
        asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                     :: "r"(mma_bar) : "memory");
    }
    // Every thread waits for phase 0 of the mbarrier, i.e. for the MMA to finish
    // writing C. Without this, C could be read (or TMEM deallocated) mid-MMA.
    for (uint32_t done = 0; !done;) {
        asm volatile("{\n\t.reg .pred p;\n\t"
                     "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                     "selp.u32 %0, 1, 0, p;\n\t}"
                     : "=r"(done)
                     : "r"(mma_bar), "r"(0u)
                     : "memory");
    }
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();
    printf("After MMA: tid=%d\n", tid);

    // ---- Read C (rows 0-31) from TMEM ----
    if (wid == 0) {
        float c_vals[16];
        for (int n = 0; n < 2; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];\n\t"
                         "tcgen05.wait::ld.sync.aligned;"
                         : "=f"(tmp[0]), "=f"(tmp[1]), "=f"(tmp[2]), "=f"(tmp[3]),
                           "=f"(tmp[4]), "=f"(tmp[5]), "=f"(tmp[6]), "=f"(tmp[7])
                         : "r"(tb_c + n * 8)
                         : "memory");
            for (int c = 0; c < 8; c++) c_vals[n * 8 + c] = tmp[c];
        }
        if (lane == 0) {
            printf("C[0,0..7] (row 0, lane 0): ");
            for (int c = 0; c < 8; c++) printf("%.2f ", c_vals[c]);
            printf("\n");
            // Reference value kept from the original debug session.
            // NOTE(review): a plain product of A's row (1..16) with B = 2.0 gives
            // 272.0; the "0.5 MMA scale" behind 136.0 is unexplained here. A is
            // also stored as one f32 per TMEM column, while an f16-kind A operand
            // presumably expects packed 16-bit pairs. Confirm against umma_ts_f16
            // before trusting this reference.
            float max_err = 0.0f;
            for (int c = 0; c < 16; c++) max_err = fmaxf(max_err, fabsf(c_vals[c] - 136.0f));
            printf("Max err from 136.0: %.6f\n", max_err);
        }
    }
    // Safe: the MMA has completed (mbarrier wait) and warp 0 issued the last TMEM reads.
    if (wid == 0) tmem_dealloc(tb, 64);
}
/**
 * Host driver: launches test_mma_ts with a single 128-thread block and reports
 * both launch-time and execution-time CUDA errors.
 *
 * Returns 0 on success, 1 on any CUDA error.
 */
int main() {
    printf("=== Minimal tcgen05.mma TS Test ===\n");
    // Dynamic SMEM: 4 B TMEM-address slot + 16 B alignment pad + 16x16 bf16
    // B tile (512 B) + 256 B slack, rounded up to a multiple of 128 B.
    const int smem = (4 + 16 + 16 * 16 * 2 + 256 + 127) & ~127;
    test_mma_ts<<<1, 128, smem>>>();
    // Launch-configuration errors are reported via cudaGetLastError();
    // faults raised while the kernel runs surface at the synchronize.
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess) err = cudaDeviceSynchronize();
    if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
    printf("Kernel completed!\n");
    return 0;
}