Files
nvfp4-megamoe-kernel/tests/unit/test_tmem_layout_full.cu
2026-05-28 15:49:47 +00:00

250 lines
10 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
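/**
 * Map TMEM accumulator Layout D for the PV MMA (M=128, N=64).
 * Only query row 0 is non-zero, so row 0 of O = softmax(q K^T * scale) V carries the
 * signal. Warp 0 reads TMEM columns 0-63 with tmem_load (intended 16x256b shape) and
 * every thread of that warp dumps its four values; the host then locates each reference
 * output element in the dump.
 */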
#include <cuda_runtime.h>
#include <cstdio>
#include <cstring>
#include <cmath>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
// Host-side float -> bf16 by truncation: keep the upper 16 bits of the IEEE-754
// single (sign, exponent, top 7 mantissa bits) and drop the rest (no rounding).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits = 0;
    std::memcpy(&bits, &f, sizeof bits);
    const uint16_t upper = static_cast<uint16_t>(bits >> 16);
    return upper;
}
// Host-side bf16 -> float: the 16 bf16 bits become the high half of an IEEE-754
// single whose low mantissa half is zero (exact widening).
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t widened = static_cast<uint32_t>(h) << 16;
    float out;
    std::memcpy(&out, &widened, sizeof out);
    return out;
}
constexpr int HD = 64;           // head dimension (length of q, row length of k)
constexpr int SK = 128;          // number of key/value rows
constexpr int BLOCK_MN = 128;    // M extent of every MMA, and N extent of the QK MMA
constexpr int LOCAL_MMA_K = 16;  // K extent consumed by one umma_ss_f16 call
// Elements in one BLOCK_MN x LOCAL_MMA_K bf16 operand tile (Q, K or P).
constexpr int TILE_SZ = LOCAL_MMA_K * BLOCK_MN;
// Elements in one V tile: 2 K-groups of 8 x (HD/8) row groups x 64-element core matrices.
constexpr int V_TILE_SZ = 2 * (HD / 8) * 64;
// ---- Local synchronization helpers for test_tmem_full_dump (raw PTX, sm_100-class target) ----

// Initialize the shared-memory mbarrier at `bar` to expect `count` arrivals per phase.
static __device__ __forceinline__ void tfd_mbar_init(uint32_t bar, uint32_t count) {
    asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(bar), "r"(count) : "memory");
    asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
}

// Arrive once on `bar` when every tcgen05 operation previously issued by this thread completes.
static __device__ __forceinline__ void tfd_umma_commit(uint32_t bar) {
    asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                 :: "r"(bar) : "memory");
}

// Spin until the phase of `bar` with parity `phase` has completed.
static __device__ __forceinline__ void tfd_mbar_wait(uint32_t bar, uint32_t phase) {
    uint32_t done = 0;
    do {
        asm volatile("{\n\t"
                     ".reg .pred p;\n\t"
                     "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                     "selp.u32 %0, 1, 0, p;\n\t"
                     "}"
                     : "=r"(done) : "r"(bar), "r"(phase) : "memory");
    } while (done == 0);
}

// Order this thread's earlier tcgen05 operations before a following thread barrier.
static __device__ __forceinline__ void tfd_tc_fence_before_sync() {
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
}

// Order this thread's later tcgen05 operations after a preceding thread barrier.
static __device__ __forceinline__ void tfd_tc_fence_after_sync() {
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}

// Make this thread's ordinary (generic-proxy) shared-memory stores visible to
// async-proxy readers such as the UMMA operand fetch.
static __device__ __forceinline__ void tfd_fence_proxy_async_smem() {
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
}

/**
 * Single-CTA probe: computes O[0,:] = softmax(q K^T * scale) V with UMMAs and dumps the
 * raw TMEM accumulator of the PV MMA (N = HD = 64).
 *
 * Launch: exactly one block of 128 threads; dynamic shared memory sized as in main().
 * q:  HD bf16 values (query row 0; all other Q rows are zero).
 * k:  SK x HD, row-major (k[r*HD + d]).
 * v:  HD x SK (v[d*SK + r]).
 * tmem_dump: 64 x 128 floats; entry col*128 + lane*4 + j holds register j that warp-0
 *            thread `lane` received from tmem_load at TMEM column `col`.
 * Requires an sm_100-class GPU (tcgen05 / TMEM).
 */
__global__ void __launch_bounds__(128)
test_tmem_full_dump(const bf16_t* q, const bf16_t* k, const bf16_t* v,
                    float* tmem_dump, float scale)
{
    constexpr int kThreads  = 128;
    constexpr int kQkSteps  = HD / LOCAL_MMA_K;  // K steps of the QK GEMM
    constexpr int kPvSteps  = SK / LOCAL_MMA_K;  // K steps of the PV GEMM
    constexpr int kDumpCols = 64;                // must match the host dump buffer (64 x 128)

    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    // Shared-memory layout:
    //   [0, 4)   TMEM base address written by tmem_alloc
    //   [8, 16)  mbarrier that tracks UMMA completion
    //   sQ0      4 Q tiles (16-byte aligned), sK0 4 K tiles
    //   sPk      1 P tile (128-byte aligned), sV 8 V tiles (128-byte aligned)
    //   s_p_vals SK softmax probabilities of row 0
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    const uint32_t mma_bar = (uint32_t)__cvta_generic_to_shared(sbuf + 8);
    bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 16) + 15) & ~(uintptr_t)15);
    bf16_t* sK0 = sQ0 + 4 * TILE_SZ;
    bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + 4 * TILE_SZ) + 127) & ~(uintptr_t)127);
    bf16_t* sV  = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127);
    float* s_p_vals = (float*)(sV + 8 * V_TILE_SZ);

    // ---- Stage operands. Clear everything first, barrier, then fill. ----
    // The zeroing and filling threads differ, so without the barrier a late zero
    // store could overwrite a real value.
    for (int i = tid; i < 4 * TILE_SZ; i += kThreads) { sQ0[i] = 0; sK0[i] = 0; }
    for (int i = tid; i < TILE_SZ; i += kThreads) sPk[i] = 0;
    for (int i = tid; i < 8 * V_TILE_SZ; i += kThreads) sV[i] = 0;
    __syncthreads();

    // Q: only tile row 0 (core-matrix row 0) is populated; K-chunk stride is 16*64 elements.
    for (int idx = tid; idx < kQkSteps * LOCAL_MMA_K; idx += kThreads) {
        const int kt = idx / LOCAL_MMA_K, d = idx % LOCAL_MMA_K;
        sQ0[kt * TILE_SZ + (d / 8) * 16 * 64 + (d % 8)] = q[kt * LOCAL_MMA_K + d];
    }
    // K: row r goes to row-group r/8 (stride 64) and row r%8 (stride 8) of the core matrices.
    for (int idx = tid; idx < kQkSteps * SK * LOCAL_MMA_K; idx += kThreads) {
        const int kt  = idx / (SK * LOCAL_MMA_K);
        const int rem = idx % (SK * LOCAL_MMA_K);
        const int r = rem / LOCAL_MMA_K, d = rem % LOCAL_MMA_K;
        sK0[kt * TILE_SZ + (d / 8) * 16 * 64 + (r / 8) * 64 + (r % 8) * 8 + (d % 8)] =
            k[r * HD + kt * LOCAL_MMA_K + d];
    }
    // V: the N dimension is d (HD), the K dimension is the key index r; K-chunk stride is 8*64.
    for (int idx = tid; idx < kPvSteps * HD * LOCAL_MMA_K; idx += kThreads) {
        const int kt  = idx / (HD * LOCAL_MMA_K);
        const int rem = idx % (HD * LOCAL_MMA_K);
        const int d = rem / LOCAL_MMA_K, lr = rem % LOCAL_MMA_K;
        const int r = kt * LOCAL_MMA_K + lr;
        sV[kt * V_TILE_SZ + (lr / 8) * 8 * 64 + (d / 8) * 64 + (d % 8) * 8 + (lr % 8)] =
            v[d * SK + r];
    }
    if (tid == 0) tfd_mbar_init(mma_bar, 1);
    tfd_fence_proxy_async_smem();  // ordinary st.shared above -> UMMA (async proxy) reads below
    __syncthreads();

    if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    tfd_tc_fence_before_sync();
    __syncthreads();
    tfd_tc_fence_after_sync();
    const uint32_t tb = *sTmemBase;

    // UMMAs are asynchronous. tid 0 commits them to mma_bar, and every thread waits before
    // TMEM is read or an operand buffer is rewritten.
    // NOTE(review): redundant (but harmless) if umma_ss_f16 already waits for completion.
    uint32_t phase = 0;

    // ---- QK GEMM: S = Q K^T into TMEM columns 0..BLOCK_MN-1 ----
    {
        const uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
        if (tid == 0) {
            for (int kt = 0; kt < kQkSteps; kt++) {
                const uint64_t dq = make_umma_desc_kmajor_none(
                    __cvta_generic_to_shared(sQ0 + kt * TILE_SZ), BLOCK_MN);
                const uint64_t dk = make_umma_desc_kmajor_none(
                    __cvta_generic_to_shared(sK0 + kt * TILE_SZ), BLOCK_MN);
                umma_ss_f16(tb, dq, dk, idesc, kt > 0);
            }
            tfd_umma_commit(mma_bar);
        }
        tfd_mbar_wait(mma_bar, phase);
        phase ^= 1;
        tfd_tc_fence_after_sync();
    }

    // ---- Softmax of row 0 (the 32x32b load maps warp-0 thread i to TMEM lane i; only lane 0 is real) ----
    if (wid == 0) {
        float s_vals[SK], row_max = -INFINITY;
        for (int n = 0; n < SK / 8; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                  "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                : "r"(tb + n * 8));
            asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
            if (lane == 0) for (int c = 0; c < 8; c++) {
                s_vals[n * 8 + c] = tmp[c] * scale;
                row_max = fmaxf(row_max, s_vals[n * 8 + c]);
            }
        }
        row_max = wmax(row_max);
        float row_sum = 0.0f;
        if (lane == 0) for (int j = 0; j < SK; j++) {
            s_vals[j] = expf(s_vals[j] - row_max);
            row_sum += s_vals[j];
        }
        row_sum = wsum(row_sum);
        if (lane == 0) for (int j = 0; j < SK; j++) s_p_vals[j] = s_vals[j] / row_sum;
        tfd_tc_fence_before_sync();  // TMEM reads done before the PV MMA overwrites columns 0..63
    }
    __syncthreads();

    // ---- PV GEMM (N = HD): O = P V into TMEM columns 0..HD-1 ----
    {
        const uint32_t idesc_pv = make_idesc(BLOCK_MN, HD);
        const uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), BLOCK_MN);
        for (int kt = 0; kt < kPvSteps; kt++) {
            // Only row 0 of P is written. The rest of sPk stays zero from the initial clear,
            // and the previous MMA reading sPk has completed (waited below).
            if (tid < LOCAL_MMA_K) {
                const int c = tid;
                sPk[(c / 8) * 16 * 64 + (c % 8)] = f32_to_bf16(s_p_vals[kt * LOCAL_MMA_K + c]);
            }
            tfd_fence_proxy_async_smem();
            tfd_tc_fence_before_sync();
            __syncthreads();
            if (tid == 0) {
                tfd_tc_fence_after_sync();
                const uint64_t dv = make_umma_desc_kmajor_none(
                    __cvta_generic_to_shared(sV + kt * V_TILE_SZ), HD);
                umma_ss_f16(tb, dp, dv, idesc_pv, kt > 0);
                tfd_umma_commit(mma_bar);
            }
            tfd_mbar_wait(mma_bar, phase);
            phase ^= 1;
        }
        tfd_tc_fence_after_sync();
    }

    // ---- Dump TMEM columns 0..63 from warp 0 ----
    // NOTE(review): the exact lane/column footprint of tmem_load (stated as 16x256b) is not
    // visible here. Each thread stores its four registers at col*128 + lane*4 + j, and the
    // host searches the dump for each reference value.
    if (wid == 0) {
        for (int col = 0; col < kDumpCols; col++) {
            uint32_t r0, r1, r2, r3;
            tmem_load(tb + col, r0, r1, r2, r3);
            asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
            const int base_idx = col * 128 + lane * 4;
            tmem_dump[base_idx + 0] = u32_to_f32(r0);
            tmem_dump[base_idx + 1] = u32_to_f32(r1);
            tmem_dump[base_idx + 2] = u32_to_f32(r2);
            tmem_dump[base_idx + 3] = u32_to_f32(r3);
        }
    }

    // Every TMEM access must be complete before the allocating warp releases the columns.
    tfd_tc_fence_before_sync();
    __syncthreads();
    if (wid == 1) {
        tfd_tc_fence_after_sync();
        tmem_dealloc(tb, 128);
    }
}
// Print a failed CUDA runtime call (with the call text) and return false; true on success.
static bool cuda_ok(cudaError_t err, const char* what) {
    if (err == cudaSuccess) return true;
    printf("CUDA ERROR: %s (%s)\n", cudaGetErrorString(err), what);
    return false;
}

// Host driver. Builds random bf16 q (one row), K (SK x HD) and V (HD x SK), runs
// test_tmem_full_dump on one 128-thread CTA, and computes the host reference
// O[0,:] = softmax(q K^T * SCALE) V. For each reference element it prints the
// (column, position) of the closest value in the TMEM dump. Returns 1 on any CUDA
// or host-allocation failure.
int main() {
#define CHECK_CUDA(call) do { if (!cuda_ok((call), #call)) return 1; } while (0)
    printf("=== Full TMEM Layout D dump for PV MMA N=64 ===\n");
    const float SCALE = 1.0f / sqrtf((float)HD);
    constexpr int kDumpCols = 64, kDumpPos = 128;         // must match the kernel's dump extent
    constexpr size_t kDumpElems = (size_t)kDumpCols * kDumpPos;

    bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t));
    bf16_t* h_k = (bf16_t*)malloc(SK * HD * sizeof(bf16_t));
    bf16_t* h_v = (bf16_t*)malloc(HD * SK * sizeof(bf16_t));
    float* h_dump = (float*)malloc(kDumpElems * sizeof(float));
    if (!h_q || !h_k || !h_v || !h_dump) { printf("host allocation failed\n"); return 1; }

    srand(42);
    for (int d = 0; d < HD; d++) h_q[d] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);
    for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);
    for (int i = 0; i < HD * SK; i++) h_v[i] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);

    bf16_t *d_q = nullptr, *d_k = nullptr, *d_v = nullptr;
    float* d_tmem_dump = nullptr;
    CHECK_CUDA(cudaMalloc(&d_q, HD * sizeof(bf16_t)));
    CHECK_CUDA(cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)));
    CHECK_CUDA(cudaMalloc(&d_v, HD * SK * sizeof(bf16_t)));
    // 64 columns x 128 positions per column = 8192 floats
    CHECK_CUDA(cudaMalloc(&d_tmem_dump, kDumpElems * sizeof(float)));
    CHECK_CUDA(cudaMemset(d_tmem_dump, 0, kDumpElems * sizeof(float)));
    CHECK_CUDA(cudaMemcpy(d_q, h_q, HD * sizeof(bf16_t), cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(d_v, h_v, HD * SK * sizeof(bf16_t), cudaMemcpyHostToDevice));

    // Dynamic smem: small header plus 16-byte alignment pad, 4 Q + 4 K + 1 P + 8 V bf16 tiles,
    // SK floats of probabilities, plus slack for the 128-byte alignment of the P/V tiles,
    // rounded up to a multiple of 128 bytes.
    int smem = (4 + 16 + 4 * TILE_SZ * 2 + 4 * TILE_SZ * 2 + TILE_SZ * 2 + 8 * V_TILE_SZ * 2
                + SK * 4 + 256 + 127) & ~127;
    CHECK_CUDA(cudaFuncSetAttribute(test_tmem_full_dump,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize, smem));
    test_tmem_full_dump<<<1, 128, smem>>>(d_q, d_k, d_v, d_tmem_dump, SCALE);
    CHECK_CUDA(cudaGetLastError());        // launch-configuration errors
    CHECK_CUDA(cudaDeviceSynchronize());   // errors raised while the kernel ran
    CHECK_CUDA(cudaMemcpy(h_dump, d_tmem_dump, kDumpElems * sizeof(float), cudaMemcpyDeviceToHost));

    // Reference output for row 0, computed from the same bf16-rounded inputs.
    float s[SK];
    for (int j = 0; j < SK; j++) {
        float dot = 0.0f;
        for (int d = 0; d < HD; d++) dot += bf16_to_f32_host(h_q[d]) * bf16_to_f32_host(h_k[j * HD + d]);
        s[j] = dot * SCALE;
    }
    float mx = -INFINITY;
    for (int j = 0; j < SK; j++) mx = fmaxf(mx, s[j]);
    float sm = 0.0f;
    for (int j = 0; j < SK; j++) { s[j] = expf(s[j] - mx); sm += s[j]; }
    for (int j = 0; j < SK; j++) s[j] /= sm;
    float o_ref[HD];
    for (int d = 0; d < HD; d++) {
        float ov = 0.0f;
        for (int j = 0; j < SK; j++) ov += s[j] * bf16_to_f32_host(h_v[d * SK + j]);
        o_ref[d] = ov;
    }

    // For each reference value, find the closest (col, pos) in the dump.
    // Loose tolerance: the device rounds P to bf16 before the PV MMA.
    const float kMatchTol = 1e-2f;
    int matched = 0;
    printf("=== Mapping: output position d -> (col, pos_in_col) ===\n");
    for (int d = 0; d < HD; d++) {
        const float target = o_ref[d];
        int best_col = -1, best_pos = -1;
        float best_diff = 1e10f;
        for (int col = 0; col < kDumpCols; col++) {
            for (int pos = 0; pos < kDumpPos; pos++) {
                const float diff = fabsf(h_dump[col * kDumpPos + pos] - target);
                if (diff < best_diff) {
                    best_diff = diff;
                    best_col = col;
                    best_pos = pos;
                }
            }
        }
        // NaN target or an all-NaN dump never compares smaller: avoid indexing with -1.
        if (best_col < 0) {
            printf("  d=%2d: ref=%10.6f  no comparable value in dump\n", d, target);
            continue;
        }
        if (best_diff <= kMatchTol) matched++;
        printf("  d=%2d: ref=%10.6f at (col=%2d, pos=%3d) val=%10.6f diff=%.2e\n",
               d, target, best_col, best_pos, h_dump[best_col * kDumpPos + best_pos], best_diff);
    }
    printf("Matched %d/%d reference outputs within %.0e\n", matched, HD, kMatchTol);

    // Row 0 pattern: the first four positions (thread 0's registers) of the first 16 columns.
    printf("\n=== Row 0 pattern: col vs output d ===\n");
    printf("For each column, which output positions (row 0) are at pos 0..3 (lane 0)?\n");
    for (int col = 0; col < 16; col++) {
        printf("  col %2d pos 0..3: %10.6f %10.6f %10.6f %10.6f\n", col,
               h_dump[col * kDumpPos + 0], h_dump[col * kDumpPos + 1],
               h_dump[col * kDumpPos + 2], h_dump[col * kDumpPos + 3]);
    }

    cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_tmem_dump);
    free(h_q); free(h_k); free(h_v); free(h_dump);
    return 0;
}
#undef CHECK_CUDA