// Files
// nvfp4-megamoe-kernel/tests/unit/test_tmem_layout_pv64.cu
// 2026-05-28 15:48:15 +00:00
//
// 276 lines
// 11 KiB
// Plaintext
// Raw Blame History
//
// This file contains ambiguous Unicode characters
// This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
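/**
* Map TMEM Layout D for PV MMA at N=64 (HD=64).
* Compute S = Q K^T and its softmax P for a single query row (Q, K, V are
* random values generated on the host), then accumulate the PV MMA
* (8 K-slices of 16) with N=64 into the same TMEM allocation.
* Then read ALL 128 TMEM columns and print which positions
* correspond to which output element.
*/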
#include <cuda_runtime.h>
#include <cstdio>
#include <cstring>
#include <cmath>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
/** Host-side float -> bf16: keeps the upper 16 bits of the float's bit pattern (truncation, no rounding). */
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    memcpy(&bits, &f, 4);          // type-pun through memcpy
    const uint32_t upper_half = bits >> 16;
    return (uint16_t)upper_half;
}
/** Host-side bf16 -> float: places the 16 stored bits in the upper half of a 32-bit word and reinterprets it as float. */
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = (uint32_t)h << 16;   // low 16 bits are zero
    float result;
    memcpy(&result, &bits, 4);                 // type-pun through memcpy
    return result;
}
// Problem shape used by this test.
constexpr int HD = 64;        // inner dimension of Q/K and N of the PV MMA
constexpr int SK = 128;       // number of K rows / V columns
constexpr int BLOCK_MN = 128; // M (and N for the QK MMA) passed to make_idesc

// K extent consumed by one MMA issue.
constexpr int LOCAL_MMA_K = 16;

// Elements in one staged Q/K/P tile: BLOCK_MN rows x LOCAL_MMA_K columns.
constexpr int TILE_SZ = BLOCK_MN * LOCAL_MMA_K;

// Elements in one staged V tile: (HD / 8) groups x 2 groups x 64 elements each.
constexpr int V_TILE_SZ = (HD / 8) * 2 * 64; // 1024
/**
 * Probe kernel for the TMEM layout produced by the PV MMA at N = HD = 64.
 *
 * Steps: stage Q/K/V into shared memory, run the QK MMA (4 K-slices), compute
 * a softmax over row 0 of the scores, run the PV MMA (8 K-slices) into the
 * SAME TMEM allocation, then dump row 0 of all 128 allocated TMEM columns.
 *
 * Launch: one block of exactly 128 threads (strides below are hard-coded to
 * 128) with the dynamic shared memory size computed in main().
 *
 * @param q          Query row; elements [0, 4 * LOCAL_MMA_K) are read.
 * @param k          Keys, indexed as k[r * HD + d] for r < SK.
 * @param v          Values, indexed as v[d * SK + r] for d < HD, r < SK.
 * @param tmem_dump  Output, 128 * 4 floats. Entry [col * 4 + c] receives the
 *                   value lane 0 of warp 0 reads for TMEM column (col + c);
 *                   entries with col + c >= 128 are written as 0.
 * @param scale      Factor applied to the raw scores before the softmax.
 */
__global__ void __launch_bounds__(128)
test_tmem_layout_pv(const bf16_t* q, const bf16_t* k, const bf16_t* v,
                    float* tmem_dump, float scale)
{
    constexpr int TMEM_COLS = 128;  // columns requested from tmem_alloc

    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    // Carve the dynamic shared memory: TMEM base slot, 4 Q tiles, 4 K tiles,
    // one P tile, 8 V tiles, and SK floats of softmax probabilities.
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sK0 = sQ0 + 4 * TILE_SZ;
    bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + 4 * TILE_SZ) + 127) & ~(uintptr_t)127);
    bf16_t* sV  = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127);
    float* s_p_vals = (float*)(sV + 8 * V_TILE_SZ);

    // ---- Stage 1: zero every Q/K/V tile -------------------------------------
    for (int kt = 0; kt < 4; kt++) {
        bf16_t* sq = sQ0 + kt * TILE_SZ;
        bf16_t* sk = sK0 + kt * TILE_SZ;
        for (int i = tid; i < TILE_SZ; i += 128) {
            sq[i] = 0;
            sk[i] = 0;
        }
    }
    for (int kt = 0; kt < 8; kt++) {
        bf16_t* sv = sV + kt * V_TILE_SZ;
        for (int i = tid; i < V_TILE_SZ; i += 128) sv[i] = 0;
    }
    // The scatter below writes elements that a DIFFERENT thread (often in a
    // different warp) zeroed above; without this barrier a late zero store
    // could overwrite a freshly staged value.
    __syncthreads();

    // ---- Stage 2: scatter Q (row 0 only), K and V into the tiled layout -----
    for (int kt = 0; kt < 4; kt++) {
        bf16_t* sq = sQ0 + kt * TILE_SZ;
        for (int d = tid; d < LOCAL_MMA_K; d += 128) {
            int ck = d / 8, lc = d % 8;
            sq[ck * 16 * 64 + lc] = q[kt * LOCAL_MMA_K + d];
        }
    }
    for (int kt = 0; kt < 4; kt++) {
        bf16_t* sk = sK0 + kt * TILE_SZ;
        for (int r = 0; r < SK; r++) {
            for (int d = tid; d < LOCAL_MMA_K; d += 128) {
                int ck = d / 8, lc = d % 8;
                int tmn = r / 8, lr = r % 8;
                sk[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k[r * HD + kt * LOCAL_MMA_K + d];
            }
        }
    }
    for (int kt = 0; kt < 8; kt++) {
        bf16_t* sv = sV + kt * V_TILE_SZ;
        for (int d = tid; d < HD; d += 128) {
            for (int lr = 0; lr < LOCAL_MMA_K; lr++) {
                int r = kt * LOCAL_MMA_K + lr;
                int g_mn = d / 8, g_k = lr / 8;
                int llr = d % 8, lc = lr % 8;
                sv[g_k * 8 * 64 + g_mn * 64 + llr * 8 + lc] = v[d * SK + r];
            }
        }
    }
    __syncthreads();

    // ---- TMEM allocation: warp 1 allocates, everyone reads the base ---------
    if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_COLS);
    __syncthreads();
    uint32_t tb = *sTmemBase;

    // ---- QK GEMM: 4 K-slices accumulated into TMEM at tb --------------------
    {
        uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
        for (int kt = 0; kt < 4; kt++) {
            bf16_t* sq = sQ0 + kt * TILE_SZ;
            bf16_t* sk = sK0 + kt * TILE_SZ;
            uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sq), BLOCK_MN);
            uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sk), BLOCK_MN);
            // Last argument is (kt > 0): false on the first slice, true afterwards.
            if (tid == 0) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
            __syncthreads();
        }
    }

    // ---- Softmax over row 0 (only lane 0 of warp 0 keeps values) ------------
    if (wid == 0) {
        float s_vals[SK], row_max = -INFINITY;
        for (int n = 0; n < SK / 8; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                         : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                           "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                         : "r"(tb + n*8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            if (lane == 0) for (int c=0;c<8;c++) {
                s_vals[n*8+c] = tmp[c] * scale;
                row_max = fmaxf(row_max, tmp[c] * scale);
            }
        }
        // Other lanes still hold -INFINITY / 0, so the warp reductions simply
        // hand lane 0's value to every lane.
        row_max = wmax(row_max);
        float row_sum = 0.0f;
        if (lane == 0) for (int j=0;j<SK;j++) {
            s_vals[j] = expf(s_vals[j] - row_max);
            row_sum += s_vals[j];
        }
        row_sum = wsum(row_sum);
        if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] /= row_sum;
        if (lane == 0) for (int j=0;j<SK;j++) s_p_vals[j] = s_vals[j];
    }
    __syncthreads();

    // ---- PV MMA with N = HD = 64 into the same TMEM (still holding S) -------
    // Intent (from the original notes): PV overwrites columns 0..63 and leaves
    // columns 64..127 with the old S data.
    {
        uint32_t idesc_pv = make_idesc(BLOCK_MN, HD); // (128, 64)
        for (int kt = 0; kt < 8; kt++) {
            // NOTE(review): sPk is rewritten every iteration; this relies on
            // the previous umma_ss_f16 having finished reading it by the time
            // the barrier at the end of the loop body is passed - confirm
            // against the helper's definition.
            for (int i = tid; i < TILE_SZ; i += 128) sPk[i] = 0;
            // Elements 1024..1031 are zeroed by threads 0..7 but written by
            // threads 8..15 below, so order the two phases explicitly.
            __syncthreads();
            if (tid < 16) {
                int c = tid;
                int ck = c / 8, lc = c % 8;
                sPk[ck * 16 * 64 + 0 * 64 + 0 * 8 + lc] = f32_to_bf16(s_p_vals[kt * LOCAL_MMA_K + c]);
            }
            __syncthreads();
            bf16_t* sv = sV + kt * V_TILE_SZ;
            uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), BLOCK_MN);
            uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sv), HD);
            if (tid == 0) umma_ss_f16(tb, dp, dv, idesc_pv, kt > 0);
            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
            __syncthreads();
        }
    }

    // ---- Dump row 0 of all TMEM_COLS columns --------------------------------
    // Read the row in TMEM_COLS / 8 aligned .x8 chunks so every load stays
    // inside the allocation. (Issuing an .x8 load at every column offset up to
    // 127 would touch columns up to 134, past the 128 that were allocated.)
    // Presumably each thread receives 8 consecutive columns of its own TMEM
    // lane for the 32x32b shape - see the PTX ISA; only lane 0's values are
    // used here.
    if (wid == 0) {
        float row0[TMEM_COLS];
        for (int n = 0; n < TMEM_COLS / 8; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                         : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                           "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                         : "r"(tb + n*8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            for (int c = 0; c < 8; c++) row0[n*8 + c] = tmp[c];
        }
        if (lane == 0) {
            // Same output layout as before: slot c of entry `col` is the c-th
            // register of an .x8 load issued at column `col`, i.e. column col + c.
            for (int col = 0; col < TMEM_COLS; col++) {
                for (int c = 0; c < 4; c++) {
                    int src = col + c;
                    tmem_dump[col * 4 + c] = (src < TMEM_COLS) ? row0[src] : 0.0f;
                }
            }
        }
    }
    __syncthreads();
    if (wid == 0) tmem_dealloc(tb, TMEM_COLS);
}
/**
 * Host driver: builds random Q/K/V, launches test_tmem_layout_pv on one block
 * of 128 threads, computes a CPU softmax(QK^T * scale) V reference for the
 * single query row, prints the dumped TMEM values and, for every output
 * element, the dump entry whose value is closest to the reference.
 *
 * Returns 0 on success, 1 on any host allocation or CUDA failure.
 */
int main() {
    printf("=== TMEM Layout D mapping for PV MMA N=64 ===\n");
    const float SCALE = 1.0f / sqrtf((float)HD);

    // 128 columns x 4 slots per column (lane 0 only) = 512 floats.
    constexpr int DUMP_COLS = 128;
    constexpr int DUMP_SLOTS = 4;
    constexpr size_t DUMP_BYTES = DUMP_COLS * DUMP_SLOTS * sizeof(float);

    bf16_t* h_q = (bf16_t*)malloc(HD*sizeof(bf16_t));
    bf16_t* h_k = (bf16_t*)malloc(SK*HD*sizeof(bf16_t));
    bf16_t* h_v = (bf16_t*)malloc(HD*SK*sizeof(bf16_t));
    float* h_dump = (float*)malloc(DUMP_BYTES);
    bf16_t *d_q = nullptr, *d_k = nullptr, *d_v = nullptr;
    float* d_tmem_dump = nullptr;

    // Frees everything allocated so far; null pointers are accepted by both
    // cudaFree and free, so this is safe on every exit path.
    auto release_all = [&]() {
        cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_tmem_dump);
        free(h_q); free(h_k); free(h_v); free(h_dump);
    };
    // Reports a CUDA failure and cleans up; the caller returns 1 when false.
    auto cuda_ok = [&](cudaError_t status) -> bool {
        if (status == cudaSuccess) return true;
        printf("CUDA ERROR: %s\n", cudaGetErrorString(status));
        release_all();
        return false;
    };

    if (!h_q || !h_k || !h_v || !h_dump) {
        printf("Host allocation failed\n");
        release_all();
        return 1;
    }

    // Fixed seed and fixed draw order (q, then k, then v) keep the inputs reproducible.
    srand(42);
    for (int d=0;d<HD;d++) h_q[d] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
    for (int i=0;i<SK*HD;i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
    for (int i=0;i<HD*SK;i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);

    if (!cuda_ok(cudaMalloc(&d_q, HD*sizeof(bf16_t)))) return 1;
    if (!cuda_ok(cudaMalloc(&d_k, SK*HD*sizeof(bf16_t)))) return 1;
    if (!cuda_ok(cudaMalloc(&d_v, HD*SK*sizeof(bf16_t)))) return 1;
    if (!cuda_ok(cudaMalloc(&d_tmem_dump, DUMP_BYTES))) return 1;
    if (!cuda_ok(cudaMemset(d_tmem_dump, 0, DUMP_BYTES))) return 1;
    if (!cuda_ok(cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice))) return 1;
    if (!cuda_ok(cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice))) return 1;
    if (!cuda_ok(cudaMemcpy(d_v, h_v, HD*SK*sizeof(bf16_t), cudaMemcpyHostToDevice))) return 1;

    // Dynamic shared memory: TMEM slot + alignment slack, 4 Q tiles, 4 K tiles,
    // 1 P tile, 8 V tiles, SK floats, plus slack for the two 128-byte
    // alignments done in the kernel; rounded up to a multiple of 128.
    int smem = (4+16 + 4*TILE_SZ*2 + 4*TILE_SZ*2 + TILE_SZ*2 + 8*V_TILE_SZ*2 + SK*4 + 256 + 127) & ~127;
    printf("SMEM: %d bytes (%.1f KB)\n", smem, smem/1024.0f);
    if (!cuda_ok(cudaFuncSetAttribute(test_tmem_layout_pv,
                                      cudaFuncAttributeMaxDynamicSharedMemorySize, smem))) return 1;

    test_tmem_layout_pv<<<1, 128, smem>>>(d_q, d_k, d_v, d_tmem_dump, SCALE);
    if (!cuda_ok(cudaGetLastError())) return 1;        // launch-configuration errors
    if (!cuda_ok(cudaDeviceSynchronize())) return 1;   // errors raised while the kernel ran
    if (!cuda_ok(cudaMemcpy(h_dump, d_tmem_dump, DUMP_BYTES, cudaMemcpyDeviceToHost))) return 1;

    // ---- CPU reference for the single query row -----------------------------
    float s[SK];
    for (int j=0;j<SK;j++) {
        float dot = 0.0f;
        for (int d=0;d<HD;d++) dot += bf16_to_f32_host(h_q[d]) * bf16_to_f32_host(h_k[j*HD+d]);
        s[j] = dot * SCALE;
    }
    float mx = -INFINITY;
    for (int j=0;j<SK;j++) mx = fmaxf(mx, s[j]);
    float sm = 0.0f;
    for (int j=0;j<SK;j++) { s[j] = expf(s[j]-mx); sm += s[j]; }
    for (int j=0;j<SK;j++) s[j] /= sm;
    float o_ref[HD];
    for (int d=0;d<HD;d++) {
        float ov = 0.0f;
        for (int j=0;j<SK;j++) ov += s[j] * bf16_to_f32_host(h_v[d*SK+j]);
        o_ref[d] = ov;
    }

    // Prints every dump entry in [col_begin, col_end) that has a slot with |value| > 1e-6.
    auto print_nonzero_cols = [&](int col_begin, int col_end) {
        for (int col = col_begin; col < col_end; col++) {
            const float* slots = h_dump + col * DUMP_SLOTS;
            bool nonzero = false;
            for (int p = 0; p < DUMP_SLOTS; p++) if (fabsf(slots[p]) > 1e-6f) nonzero = true;
            if (nonzero) {
                printf(" col %3d: %10.6f %10.6f %10.6f %10.6f\n", col,
                       slots[0], slots[1], slots[2], slots[3]);
            }
        }
    };

    printf("\n--- TMEM dump (lane 0, 4 positions per column) ---\n");
    printf("Showing non-zero values in first 64 columns (PV output N=64):\n");
    print_nonzero_cols(0, 64);
    printf("\nShowing non-zero values in columns 64-127 (should be zero for PV output):\n");
    print_nonzero_cols(64, 128);

    // For each of the 64 output positions d (row 0 only), find which
    // (col, slot) dump entry is closest to o_ref[d].
    printf("\n--- Mapping output position -> (col, slot) ---\n");
    for (int d = 0; d < HD; d++) {
        float target = o_ref[d];
        int best_col = -1, best_slot = -1;
        float best_diff = 1e10f;
        for (int col = 0; col < DUMP_COLS; col++) {
            for (int p = 0; p < DUMP_SLOTS; p++) {
                float diff = fabsf(h_dump[col*DUMP_SLOTS+p] - target);
                if (diff < best_diff) {
                    best_diff = diff;
                    best_col = col;
                    best_slot = p;
                }
            }
        }
        if (best_col < 0) {
            // Every comparison failed (e.g. the dump is all NaN); indexing
            // h_dump with -1 would read out of bounds.
            printf(" d=%2d: ref=%10.6f no comparable value in dump\n", d, target);
            continue;
        }
        printf(" d=%2d: ref=%10.6f found at (col=%3d, slot=%d) val=%10.6f diff=%.2e\n",
               d, target, best_col, best_slot, h_dump[best_col*DUMP_SLOTS+best_slot], best_diff);
    }

    release_all();
    return 0;
}