FMHA with SMEM-P approach: PV via SS MMA avoids TMEM layout mismatch
This commit is contained in:
265
tests/unit/test_fmha_smem_p.cu
Normal file
265
tests/unit/test_fmha_smem_p.cu
Normal file
@@ -0,0 +1,265 @@
|
||||
/**
|
||||
* Full FMHA HD=16, SK=128 — PV via SS MMA (SMEM-P approach)
|
||||
*
|
||||
* Pipeline: Q×K^T (SS) → softmax (TMEM read → SMEM write) → P×V (SS) → epilogue
|
||||
*
|
||||
* Key insight: the tcgen05.mma TS A-operand TMEM layout (Layout A) does NOT
|
||||
* match the 32x32b store format. Using SS MMA for both QK and PV avoids the
|
||||
* TMEM layout issue entirely, because both operands come from SMEM where we
|
||||
* control the canonical K-major layout.
|
||||
*
|
||||
* This is the SMEM-P approach, similar to what CuTeDSL uses for hd > 64,
|
||||
* but applied at all head dims for the raw CUDA path.
|
||||
*
|
||||
* SMEM layout:
|
||||
* sQ: (128, 16) — Q K-tile
|
||||
* sK: (128, 16) — K K-tile
|
||||
* sP: (128, 128) — softmax output, written in canonical K-major layout
|
||||
* sV: 8 × (16, 16) — V K-tiles
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdio>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
#include "dsv4/kernels/attention/fmha_common.cuh"
|
||||
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
||||
|
||||
using namespace dsv4::kernels::attention;
|
||||
|
||||
static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
|
||||
static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; }
|
||||
|
||||
constexpr int HD = 16, SK = 128, BLOCK_MN = 128;
|
||||
constexpr int NKT_QK = HD / MMA_K_BF16; // 1
|
||||
constexpr int NKT_PV = SK / MMA_K_BF16; // 8
|
||||
constexpr int TMEM_N = 128; // Just S and O, no P in TMEM
|
||||
constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16;
|
||||
|
||||
__global__ void __launch_bounds__(128)
|
||||
test_fmha_smem_p(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
|
||||
const bf16_t* __restrict__ v, bf16_t* __restrict__ o_out,
|
||||
float* __restrict__ o_scalar, float scale)
|
||||
{
|
||||
const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;
|
||||
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf;
|
||||
bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
|
||||
bf16_t* sK0 = sQ0 + TILE_SZ;
|
||||
// sP: softmax output in canonical (128, 128) layout
|
||||
// (128, 128): CORES_MN=16, CORES_K=16
|
||||
// Each core: 64 BF16. Total: 16*16*64 = 16384 BF16 = 32768 bytes
|
||||
bf16_t* sP = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127);
|
||||
// sV: 8 K-tiles of (16, 16)
|
||||
bf16_t* sV = (bf16_t*)(((uintptr_t)(sP + 128 * SK) + 127) & ~(uintptr_t)127);
|
||||
|
||||
// Load Q, K
|
||||
write_q_to_smem<HD>(sQ0, q);
|
||||
write_k_to_smem<SK, HD>(sK0, k);
|
||||
|
||||
// Load V K-tiles
|
||||
for (int kt = 0; kt < NKT_PV; kt++) {
|
||||
bf16_t* sv = sV + kt * 256;
|
||||
for (int i = tid; i < 256; i += 128) sv[i] = 0;
|
||||
for (int d = tid; d < HD; d += 128) {
|
||||
for (int lr = 0; lr < MMA_K_BF16; lr++) {
|
||||
int r = kt * MMA_K_BF16 + lr;
|
||||
int ck = d / 8, lc = d % 8;
|
||||
int tmn = lr / 8, llr = lr % 8;
|
||||
int dst_idx = ck * 2 * 64 + tmn * 64 + llr * 8 + lc;
|
||||
sv[dst_idx] = v[d * SK + r];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// TMEM alloc: 128 columns for S
|
||||
if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// ===== STEP 1: QK GEMM (SS) =====
|
||||
{
|
||||
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), BLOCK_MN);
|
||||
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), BLOCK_MN);
|
||||
uint32_t idesc_qk = make_idesc(BLOCK_MN, BLOCK_MN);
|
||||
for (int kt = 0; kt < NKT_QK; kt++) {
|
||||
if (tid == 0) umma_ss_f16(tb, dq, dk, idesc_qk, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// ===== STEP 2: Softmax — read S from TMEM, write P to SMEM =====
|
||||
if (wid == 0) {
|
||||
float s_vals[SK], row_max = -INFINITY;
|
||||
for (int n = 0; n < SK / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n*8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) {
|
||||
s_vals[n*8+c] = tmp[c] * scale;
|
||||
row_max = fmaxf(row_max, tmp[c] * scale);
|
||||
}
|
||||
}
|
||||
row_max = wmax(row_max);
|
||||
float row_sum = 0.0f;
|
||||
if (lane == 0) for (int j=0;j<SK;j++) {
|
||||
s_vals[j] = expf(s_vals[j] - row_max);
|
||||
row_sum += s_vals[j];
|
||||
}
|
||||
row_sum = wsum(row_sum);
|
||||
if (lane == 0) for (int j=0;j<SK;j++) s_vals[j] /= row_sum;
|
||||
|
||||
// Write P to SMEM in canonical (128, 128) layout
|
||||
// Only row 0 has data (decode T=1)
|
||||
// Row 0 in canonical: core_mn=0, local_r=0
|
||||
// P[0, j] → core_k=j/8, core_mn=0, local_r=0, local_c=j%8
|
||||
// dst_idx = core_k * 16 * 64 + 0 * 64 + 0 * 8 + local_c
|
||||
if (lane == 0) {
|
||||
// Zero sP first (only row 0 area needs clearing for decode,
|
||||
// but for safety zero the whole thing)
|
||||
// Actually, we need to zero the WHOLE sP because MMA reads all 128 rows
|
||||
// Let's have all lanes participate in the zeroing
|
||||
}
|
||||
// All 32 lanes zero sP
|
||||
for (int i = lane; i < 128 * SK; i += 32) sP[i] = 0;
|
||||
// Then lane 0 writes row 0
|
||||
if (lane == 0) {
|
||||
for (int j = 0; j < SK; j++) {
|
||||
int core_k = j / 8, lc = j % 8;
|
||||
int dst_idx = core_k * 16 * 64 + 0 * 64 + 0 * 8 + lc;
|
||||
sP[dst_idx] = f32_to_bf16(s_vals[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads(); // Ensure sP is visible to all warps
|
||||
|
||||
// ===== STEP 3: PV GEMM (SS) =====
|
||||
// P(128, 128) × V(128, 16) → O(128, 16)
|
||||
// 8 K-tiles: A = P's kt-th 16 columns (SMEM), B = V's kt-th K-tile (SMEM)
|
||||
// P in canonical (128, 128): each (128, 16) K-tile is 2048 BF16
|
||||
// The kt-th K-tile starts at offset kt * 2048 in the canonical layout
|
||||
{
|
||||
uint32_t idesc_pv = make_idesc(BLOCK_MN, HD); // M=128, N=16
|
||||
|
||||
for (int kt = 0; kt < NKT_PV; kt++) {
|
||||
// A = P's kt-th (128, 16) K-tile in SMEM
|
||||
// In canonical (128, 128), the kt-th (128, 16) K-tile:
|
||||
// CORES_K = 16, each (128, 16) K-tile has 2 core columns
|
||||
// The kt-th (128, 16) starts at core_k=kt
|
||||
// In the flat SMEM, core_k starts at offset core_k * 16 * 64 = kt * 1024 BF16
|
||||
bf16_t* sp = sP + kt * 1024; // 16 core_mn * 64 = 1024 BF16 per core_k group
|
||||
uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sp), BLOCK_MN);
|
||||
|
||||
// B = V's kt-th (16, 16) K-tile
|
||||
bf16_t* sv = sV + kt * 256;
|
||||
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sv), 16);
|
||||
|
||||
bool accumulate = (kt > 0);
|
||||
if (tid == 0) umma_ss_f16(tb, dp, dv, idesc_pv, accumulate);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// ===== STEP 4: Epilogue — read O from TMEM =====
|
||||
if (wid == 0) {
|
||||
float o_vals[HD];
|
||||
for (int n = 0; n < HD / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n*8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * 2.0f; // Undo MMA 0.5 scale
|
||||
}
|
||||
if (lane == 0) for (int d=0;d<HD;d++) o_out[d] = f32_to_bf16(o_vals[d]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Scalar reference
|
||||
if (tid == 0) {
|
||||
float s[SK];
|
||||
for (int j=0;j<SK;j++) {
|
||||
float dot = 0.0f;
|
||||
for (int d=0;d<HD;d++) dot += bf16_to_f32(q[d]) * bf16_to_f32(k[j*HD+d]);
|
||||
s[j] = dot * scale;
|
||||
}
|
||||
float mx = -INFINITY;
|
||||
for (int j=0;j<SK;j++) mx = fmaxf(mx, s[j]);
|
||||
float sm = 0.0f;
|
||||
for (int j=0;j<SK;j++) { s[j] = expf(s[j]-mx); sm += s[j]; }
|
||||
for (int j=0;j<SK;j++) s[j] /= sm;
|
||||
for (int d=0;d<HD;d++) {
|
||||
float ov = 0.0f;
|
||||
for (int j=0;j<SK;j++) ov += s[j] * bf16_to_f32(v[d*SK+j]);
|
||||
o_scalar[d] = ov;
|
||||
}
|
||||
}
|
||||
|
||||
if (wid == 0) tmem_dealloc(tb, TMEM_N);
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("=== Full FMHA HD=16 with SMEM-P (PV via SS MMA) ===\n");
|
||||
const float SCALE = 1.0f / sqrtf((float)HD);
|
||||
|
||||
bf16_t* h_q = (bf16_t*)malloc(HD*sizeof(bf16_t));
|
||||
bf16_t* h_k = (bf16_t*)malloc(SK*HD*sizeof(bf16_t));
|
||||
bf16_t* h_v = (bf16_t*)malloc(HD*SK*sizeof(bf16_t));
|
||||
bf16_t* h_o = (bf16_t*)calloc(HD, sizeof(bf16_t));
|
||||
float* h_o_scalar = (float*)calloc(HD, sizeof(float));
|
||||
|
||||
srand(42);
|
||||
for (int d=0;d<HD;d++) h_q[d] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
|
||||
for (int i=0;i<SK*HD;i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
|
||||
for (int i=0;i<HD*SK;i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f-0.5f);
|
||||
|
||||
bf16_t *d_q,*d_k,*d_v,*d_o; float *d_o_scalar;
|
||||
cudaMalloc(&d_q, HD*sizeof(bf16_t));
|
||||
cudaMalloc(&d_k, SK*HD*sizeof(bf16_t));
|
||||
cudaMalloc(&d_v, HD*SK*sizeof(bf16_t));
|
||||
cudaMalloc(&d_o, HD*sizeof(bf16_t));
|
||||
cudaMalloc(&d_o_scalar, HD*sizeof(float));
|
||||
cudaMemcpy(d_q, h_q, HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_v, h_v, HD*SK*sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// SMEM: tmem(4+12) + sQ(4096) + sK(4096) + sP(32768) + sV(4096) + alignment
|
||||
int smem = (4+16 + TILE_SZ*2 + 128*SK*2 + NKT_PV*256*2 + 256 + 127) & ~127;
|
||||
printf("SMEM requested: %d bytes (%.1f KB)\n", smem, smem/1024.0f);
|
||||
test_fmha_smem_p<<<1, 128, smem>>>(d_q, d_k, d_v, d_o, d_o_scalar, SCALE);
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
|
||||
|
||||
cudaMemcpy(h_o, d_o, HD*sizeof(bf16_t), cudaMemcpyDeviceToHost);
|
||||
cudaMemcpy(h_o_scalar, d_o_scalar, HD*sizeof(float), cudaMemcpyDeviceToHost);
|
||||
|
||||
printf("O[0..15] MMA: "); for(int d=0;d<HD;d++) printf("%.6f ",bf16_to_f32_host(h_o[d])); printf("\n");
|
||||
printf("O[0..15] ref: "); for(int d=0;d<HD;d++) printf("%.6f ",h_o_scalar[d]); printf("\n");
|
||||
|
||||
float max_diff=0, max_val=0;
|
||||
for (int d=0;d<HD;d++) {
|
||||
max_diff = fmaxf(max_diff, fabsf(bf16_to_f32_host(h_o[d]) - h_o_scalar[d]));
|
||||
max_val = fmaxf(max_val, fabsf(h_o_scalar[d]));
|
||||
}
|
||||
float rel_err = max_val>0 ? max_diff/max_val : max_diff;
|
||||
float cos_sim=0,na=0,nb=0;
|
||||
for (int d=0;d<HD;d++) { float a=bf16_to_f32_host(h_o[d]),b=h_o_scalar[d]; cos_sim+=a*b; na+=a*a; nb+=b*b; }
|
||||
cos_sim /= (sqrtf(na)*sqrtf(nb)+1e-10f);
|
||||
printf("Max rel err: %.8f | cosine: %.8f\n", rel_err, cos_sim);
|
||||
printf("Test %s\n", cos_sim > 0.999f ? "PASSED" : "FAILED");
|
||||
|
||||
cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_o_scalar);
|
||||
free(h_q); free(h_k); free(h_v); free(h_o); free(h_o_scalar);
|
||||
return cos_sim > 0.999f ? 0 : 1;
|
||||
}
|
||||
Reference in New Issue
Block a user