/**
 * Map the TMEM accumulator layout ("Layout D") produced by the P*V UMMA with
 * N = 64, using 16x256b TMEM reads.
 *
 * Reads ALL positions (all lanes) of columns 0-63 and matches them on the host
 * against a CPU reference to recover where each output element lands.
 */
#include <cuda_runtime.h>

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"

using namespace dsv4::kernels::attention;

// Host-side float -> bf16: keep the high 16 bits of the IEEE-754 binary32
// pattern (plain truncation, no rounding). The CPU reference later widens the
// very same bits back with bf16_to_f32_host, so host and device agree.
// NOTE(review): assumes bf16_t stores the raw 16-bit pattern as an integer
// type — confirm against fmha_common.cuh.
static bf16_t f32_to_bf16_host(float f)
{
    uint32_t bits = 0;
    std::memcpy(&bits, &f, sizeof(bits));
    const uint16_t high_half = static_cast<uint16_t>(bits >> 16);
    return high_half;
}

// Host-side bf16 -> float: place the 16-bit pattern in the upper half of a
// binary32 word (low mantissa bits zero) and reinterpret it as float.
// NOTE(review): assumes bf16_t stores the raw 16-bit pattern as an integer
// type — confirm against fmha_common.cuh.
static float bf16_to_f32_host(bf16_t h)
{
    const uint32_t widened = static_cast<uint32_t>(h) << 16;
    float result;
    std::memcpy(&result, &widened, sizeof(result));
    return result;
}

// Problem geometry for this single-CTA probe.
constexpr int HD = 64;          // head dim: QK reduction depth, PV output width (N)
constexpr int SK = 128;         // key count: QK output width (N), PV reduction depth
constexpr int BLOCK_MN = 128;   // UMMA M (also N for the QK GEMM)

// Reduction depth consumed by one umma_ss_f16 call (one k-tile).
constexpr int LOCAL_MMA_K = 16;

// Elements in one BLOCK_MN x LOCAL_MMA_K operand tile (Q, K or P).
constexpr int TILE_SZ = BLOCK_MN * LOCAL_MMA_K;

// Elements in one V k-tile: (HD / 8) rows of 8x8 groups times
// (LOCAL_MMA_K / 8) k-chunks, 64 elements per 8x8 group.
constexpr int V_TILE_SZ = (HD / 8) * (LOCAL_MMA_K / 8) * 64;

// Canonical tcgen05 cross-thread ordering around a CTA barrier: order this
// thread's earlier tcgen05 operations before the barrier, and the barrier
// before this thread's later tcgen05 operations.
// Contains __syncthreads(): every thread of the block must call it.
static __device__ __forceinline__ void layout_test_tcgen05_sync()
{
    asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}

// Make this thread's ordinary (generic-proxy) shared-memory stores visible to
// the async proxy, through which the UMMA reads its descriptor-addressed
// operands. Call after writing operand tiles, before the barrier preceding the MMA.
static __device__ __forceinline__ void layout_test_fence_async_smem()
{
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
}

/**
 * Single-CTA probe for the TMEM accumulator layout of the P*V UMMA (N = HD = 64).
 *
 * Steps: S = Q*K^T via UMMA (M = N = 128), softmax of S row 0, O = P*V via UMMA
 * (M = 128, N = 64), then a raw dump of TMEM columns 0..63. Only query row 0
 * (and hence P row 0) is non-zero, so only output row 0 is meaningful.
 *
 * Launch: grid = 1, block = 128 threads (loops stride by 128; warp 0 does the
 * softmax and dump, warp 1 allocates TMEM). Uses tcgen05 instructions.
 * Dynamic shared memory: TMEM-address slot, 4 Q tiles, 4 K tiles, 1 P tile,
 * 8 V tiles (bf16), SK floats of probabilities, plus 16/128-byte alignment
 * padding — main() computes the matching size.
 *
 * @param q         query row 0: HD values
 * @param k         keys, SK x HD row-major (k[r * HD + d])
 * @param v         values stored transposed, HD x SK (v[d * SK + r])
 * @param tmem_dump 64 * 128 floats; [col * 128 + lane * 4 + j] receives
 *                  register j returned to `lane` by tmem_load(tb + col, ...)
 * @param scale     softmax scale applied to Q*K^T
 */
__global__ void __launch_bounds__(128)
test_tmem_full_dump(const bf16_t* q, const bf16_t* k, const bf16_t* v,
                    float* tmem_dump, float scale)
{
    // UMMA k-steps per GEMM. The host shared-memory size in main() is written
    // for exactly 4 Q/K tiles and 8 V tiles.
    constexpr int QK_KTILES = HD / LOCAL_MMA_K;
    constexpr int PV_KTILES = SK / LOCAL_MMA_K;
    static_assert(QK_KTILES == 4 && PV_KTILES == 8,
                  "shared-memory size computed in main() assumes 4 QK and 8 PV k-tiles");
    constexpr int kThreads = 128;  // matches __launch_bounds__ and the host launch

    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    // ---- Shared-memory carve-up ------------------------------------------
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;  // tmem_alloc writes the TMEM base here
    bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sK0 = sQ0 + QK_KTILES * TILE_SZ;
    bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + QK_KTILES * TILE_SZ) + 127) & ~(uintptr_t)127);
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127);
    float* s_p_vals = (float*)(sV + PV_KTILES * V_TILE_SZ);  // SK probabilities of row 0

    // ---- Stage operands ---------------------------------------------------
    // Zero everything first and barrier before scattering: the scatter below
    // writes addresses that *other* threads (in other warps) zero, so doing
    // both in one pass without a barrier is a data race. sPk is zeroed once;
    // the PV loop only rewrites its row-0 entries.
    for (int i = tid; i < QK_KTILES * TILE_SZ; i += kThreads) {
        sQ0[i] = 0;
        sK0[i] = 0;
    }
    for (int i = tid; i < TILE_SZ; i += kThreads) sPk[i] = 0;
    for (int i = tid; i < PV_KTILES * V_TILE_SZ; i += kThreads) sV[i] = 0;
    __syncthreads();

    // Tiles are made of 8x8 groups (8 elements per row, 64 per group) for
    // make_umma_desc_kmajor_none. In a BLOCK_MN-row tile the two 8-wide
    // k-chunks are (BLOCK_MN / 8) * 64 = 16 * 64 elements apart.

    // Q: only row 0 is populated.
    for (int i = tid; i < HD; i += kThreads) {
        const int kt = i / LOCAL_MMA_K, d = i % LOCAL_MMA_K;
        sQ0[kt * TILE_SZ + (d / 8) * 16 * 64 + (d % 8)] = q[i];
    }

    // K: SK rows; row r sits in group row r / 8, line r % 8. One element per
    // thread-iteration, so global reads are contiguous across the block.
    for (int i = tid; i < SK * HD; i += kThreads) {
        const int r = i / HD, c = i % HD;
        const int kt = c / LOCAL_MMA_K, d = c % LOCAL_MMA_K;
        sK0[kt * TILE_SZ + (d / 8) * 16 * 64 + (r / 8) * 64 + (r % 8) * 8 + (d % 8)] = k[i];
    }

    // V (given transposed, v[d * SK + r]): B operand of P*V with N = HD rows
    // and k = key index; the two 8-wide k-chunks are (HD / 8) * 64 apart.
    for (int i = tid; i < HD * SK; i += kThreads) {
        const int d = i / SK, r = i % SK;
        const int kt = r / LOCAL_MMA_K, lr = r % LOCAL_MMA_K;
        sV[kt * V_TILE_SZ + (lr / 8) * 8 * 64 + (d / 8) * 64 + (d % 8) * 8 + (lr % 8)] = v[i];
    }

    // Publish the generic-proxy stores to the async proxy read by UMMA.
    layout_test_fence_async_smem();
    __syncthreads();

    // ---- TMEM allocation (128 columns, warp 1) -----------------------------
    if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    layout_test_tcgen05_sync();
    const uint32_t tb = *sTmemBase;

    // NOTE(review): tcgen05.mma is asynchronous. Unless umma_ss_f16 itself
    // commits and waits for completion (its body is not visible here), the
    // fences/barriers below order issue only, not completion, so later TMEM
    // loads and sPk rewrites could race an in-flight MMA. Consider
    // tcgen05.commit + mbarrier wait if results look unstable.

    // ---- S = Q * K^T --------------------------------------------------------
    {
        uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN);
        for (int kt = 0; kt < QK_KTILES; kt++) {
            uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0 + kt * TILE_SZ), BLOCK_MN);
            uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0 + kt * TILE_SZ), BLOCK_MN);
            if (tid == 0) umma_ss_f16(tb, dq, dk, idesc, kt > 0);  // accumulate after first k-step
            layout_test_tcgen05_sync();
        }
    }

    // ---- Softmax of S row 0 -------------------------------------------------
    // The 32x32b load gives each thread of warp 0 one TMEM lane; lane 0's
    // values are treated as row 0 and all softmax math is done by lane 0.
    if (wid == 0) {
        float s_vals[SK];
        float row_max = -INFINITY;
        for (int n = 0; n < SK / 8; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                         : "=f"(tmp[0]), "=f"(tmp[1]), "=f"(tmp[2]), "=f"(tmp[3]),
                           "=f"(tmp[4]), "=f"(tmp[5]), "=f"(tmp[6]), "=f"(tmp[7])
                         : "r"(tb + n * 8));
            asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
            if (lane == 0) {
                for (int c = 0; c < 8; c++) {
                    s_vals[n * 8 + c] = tmp[c] * scale;
                    row_max = fmaxf(row_max, s_vals[n * 8 + c]);
                }
            }
        }
        // Other lanes contribute -INF / 0, so (assuming wmax/wsum are full-warp
        // max/sum reductions) the results equal lane 0's values.
        row_max = wmax(row_max);
        float row_sum = 0.0f;
        if (lane == 0) {
            for (int j = 0; j < SK; j++) {
                s_vals[j] = expf(s_vals[j] - row_max);
                row_sum += s_vals[j];
            }
        }
        row_sum = wsum(row_sum);
        if (lane == 0) {
            for (int j = 0; j < SK; j++) s_p_vals[j] = s_vals[j] / row_sum;
        }
    }
    // Orders warp 0's TMEM loads before the PV MMA overwrites those columns
    // and publishes s_p_vals to the threads that build P.
    layout_test_tcgen05_sync();

    // ---- O = P * V (N = 64) -----------------------------------------------
    {
        uint32_t idesc_pv = make_idesc(BLOCK_MN, HD);
        for (int kt = 0; kt < PV_KTILES; kt++) {
            // Rewrite row 0 of the P tile (same layout as Q); rows 1..127 stay zero.
            if (tid < LOCAL_MMA_K) {
                const int c = tid;
                sPk[(c / 8) * 16 * 64 + (c % 8)] = f32_to_bf16(s_p_vals[kt * LOCAL_MMA_K + c]);
            }
            layout_test_fence_async_smem();
            __syncthreads();

            bf16_t* sv = sV + kt * V_TILE_SZ;
            uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), BLOCK_MN);
            uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sv), HD);
            if (tid == 0) umma_ss_f16(tb, dp, dv, idesc_pv, kt > 0);
            layout_test_tcgen05_sync();
        }
    }

    // ---- Dump TMEM columns 0..63 -------------------------------------------
    // Each lane stores the 4 registers tmem_load(tb + col, ...) returns to it
    // at tmem_dump[col * 128 + lane * 4 + j]. Which TMEM lane/column each
    // register came from is exactly what the host maps, so no interpretation
    // happens here. NOTE(review): this test describes tmem_load as a 16x256b
    // read — confirm against its definition.
    if (wid == 0) {
        for (int col = 0; col < 64; col++) {
            uint32_t r0, r1, r2, r3;
            tmem_load(tb + col, r0, r1, r2, r3);
            asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
            float* out = tmem_dump + col * 128 + lane * 4;
            out[0] = u32_to_f32(r0);
            out[1] = u32_to_f32(r1);
            out[2] = u32_to_f32(r2);
            out[3] = u32_to_f32(r3);
        }
    }

    // All TMEM loads were waited on above; order them before the free.
    // NOTE(review): allocated by warp 1 but freed by warp 0 — confirm this is
    // intended/permitted, or free from the allocating warp.
    layout_test_tcgen05_sync();
    if (wid == 0) tmem_dealloc(tb, 128);
}

// Abort the probe with file/line context when a CUDA runtime call fails.
#define TMEM_DUMP_CHECK(call)                                                  \
    do {                                                                       \
        cudaError_t tmem_dump_err_ = (call);                                   \
        if (tmem_dump_err_ != cudaSuccess) {                                   \
            printf("CUDA ERROR: %s at %s:%d (%s)\n",                           \
                   cudaGetErrorString(tmem_dump_err_), __FILE__, __LINE__,     \
                   #call);                                                     \
            std::exit(1);                                                      \
        }                                                                      \
    } while (0)

/**
 * CPU reference for output row 0: p = softmax(scale * q . k_j) over the SK
 * keys, then o_ref[d] = sum_j p_j * v[d * SK + j]. Uses the same bf16 inputs
 * uploaded to the device; unlike the kernel, P is not rounded to bf16 here,
 * so small differences are expected.
 */
static void reference_row0(const bf16_t* h_q, const bf16_t* h_k, const bf16_t* h_v,
                           float scale, float* o_ref)
{
    float s[SK];
    for (int j = 0; j < SK; j++) {
        float dot = 0.0f;
        for (int d = 0; d < HD; d++)
            dot += bf16_to_f32_host(h_q[d]) * bf16_to_f32_host(h_k[j * HD + d]);
        s[j] = dot * scale;
    }
    float mx = -INFINITY;
    for (int j = 0; j < SK; j++) mx = fmaxf(mx, s[j]);
    float sm = 0.0f;
    for (int j = 0; j < SK; j++) { s[j] = expf(s[j] - mx); sm += s[j]; }
    for (int j = 0; j < SK; j++) s[j] /= sm;
    for (int d = 0; d < HD; d++) {
        float ov = 0.0f;
        for (int j = 0; j < SK; j++) ov += s[j] * bf16_to_f32_host(h_v[d * SK + j]);
        o_ref[d] = ov;
    }
}

/**
 * Host driver: builds deterministic Q/K/V, runs test_tmem_full_dump on one
 * 128-thread block, then for every reference output element of row 0 reports
 * the nearest (column, position) in the TMEM dump. Returns 0 on success,
 * 1 on any allocation or CUDA failure.
 */
int main() {
    printf("=== Full TMEM Layout D dump for PV MMA N=64 ===\n");
    const float SCALE = 1.0f / sqrtf((float)HD);

    // The kernel dumps 64 columns x 128 positions per column.
    constexpr int kDumpCols = 64;
    constexpr int kDumpPos = 128;
    constexpr size_t kDumpBytes = (size_t)kDumpCols * kDumpPos * sizeof(float);

    bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t));
    bf16_t* h_k = (bf16_t*)malloc(SK * HD * sizeof(bf16_t));
    bf16_t* h_v = (bf16_t*)malloc(HD * SK * sizeof(bf16_t));
    float* h_dump = (float*)malloc(kDumpBytes);
    if (!h_q || !h_k || !h_v || !h_dump) {
        printf("host allocation failed\n");
        return 1;
    }

    // Deterministic inputs in roughly [-0.5, 0.49]; the q, k, v order fixes the sequence.
    srand(42);
    for (int d = 0; d < HD; d++) h_q[d] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);
    for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);
    for (int i = 0; i < HD * SK; i++) h_v[i] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);

    bf16_t *d_q, *d_k, *d_v;
    float* d_tmem_dump;
    TMEM_DUMP_CHECK(cudaMalloc(&d_q, HD * sizeof(bf16_t)));
    TMEM_DUMP_CHECK(cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)));
    TMEM_DUMP_CHECK(cudaMalloc(&d_v, HD * SK * sizeof(bf16_t)));
    TMEM_DUMP_CHECK(cudaMalloc(&d_tmem_dump, kDumpBytes));
    TMEM_DUMP_CHECK(cudaMemset(d_tmem_dump, 0, kDumpBytes));
    TMEM_DUMP_CHECK(cudaMemcpy(d_q, h_q, HD * sizeof(bf16_t), cudaMemcpyHostToDevice));
    TMEM_DUMP_CHECK(cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice));
    TMEM_DUMP_CHECK(cudaMemcpy(d_v, h_v, HD * SK * sizeof(bf16_t), cudaMemcpyHostToDevice));

    // Mirrors the kernel's carve-up: TMEM slot + alignment slack, 4 Q tiles,
    // 4 K tiles, 1 P tile, 8 V tiles (bf16), SK floats, extra padding for the
    // 128-byte alignments; rounded to a multiple of 128 bytes.
    int smem = (4 + 16 + 4 * TILE_SZ * 2 + 4 * TILE_SZ * 2 + TILE_SZ * 2 + 8 * V_TILE_SZ * 2 + SK * 4 + 256 + 127) & ~127;
    TMEM_DUMP_CHECK(cudaFuncSetAttribute(test_tmem_full_dump, cudaFuncAttributeMaxDynamicSharedMemorySize, smem));
    test_tmem_full_dump<<<1, 128, smem>>>(d_q, d_k, d_v, d_tmem_dump, SCALE);
    TMEM_DUMP_CHECK(cudaGetLastError());  // launch-configuration errors

    cudaError_t err = cudaDeviceSynchronize();  // in-kernel faults
    if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }

    TMEM_DUMP_CHECK(cudaMemcpy(h_dump, d_tmem_dump, kDumpBytes, cudaMemcpyDeviceToHost));

    float o_ref[HD];
    reference_row0(h_q, h_k, h_v, SCALE, o_ref);

    // For each reference value, find the dump entry closest to it.
    printf("=== Mapping: output position d -> (col, pos_in_col) ===\n");
    for (int d = 0; d < HD; d++) {
        const float target = o_ref[d];
        int best_col = -1, best_pos = -1;
        float best_diff = 1e10f;
        for (int col = 0; col < kDumpCols; col++) {
            for (int pos = 0; pos < kDumpPos; pos++) {
                const float diff = fabsf(h_dump[col * kDumpPos + pos] - target);
                if (diff < best_diff) {  // false for NaN, so NaN entries never win
                    best_diff = diff;
                    best_col = col;
                    best_pos = pos;
                }
            }
        }
        if (best_col < 0) {
            // Every candidate was NaN (or absurdly far): avoid indexing with -1.
            printf(" d=%2d: ref=%10.6f no comparable value in dump\n", d, target);
            continue;
        }
        printf(" d=%2d: ref=%10.6f at (col=%2d, pos=%3d) val=%10.6f diff=%.2e\n",
               d, target, best_col, best_pos, h_dump[best_col * kDumpPos + best_pos], best_diff);
    }

    // Raw view of the first four positions (lane 0's registers) of columns 0..15.
    printf("\n=== Row 0 pattern: col vs output d ===\n");
    printf("For each column, which output positions (row 0) are at pos 0..3 (lane 0)?\n");
    for (int col = 0; col < 16; col++) {
        const float* c = h_dump + col * kDumpPos;
        printf(" col %2d pos 0..3: %10.6f %10.6f %10.6f %10.6f\n", col, c[0], c[1], c[2], c[3]);
    }

    cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_tmem_dump);
    free(h_q); free(h_k); free(h_v); free(h_dump);
    return 0;
}

#undef TMEM_DUMP_CHECK