/**
 * Test: Softmax → PV (1 K-tile) via TS MMA.
 *
 * Read a random score row S from global memory, run softmax on it, write P to
 * TMEM, then issue a TS MMA (A = P from TMEM, B = V from SMEM) that computes
 * P·V for a single 16-token K-tile.
 * This isolates the softmax→PV path from QK.
 */

#include <cuda_runtime.h>

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>

#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"

using namespace dsv4::kernels::attention;

// Convert an fp32 value to bf16 bits on the host using round-to-nearest-even
// (the same rounding as the device's cvt.rn.bf16 conversions). Plain truncation
// would bias every value toward zero and could turn a NaN whose payload lives
// only in the low 16 bits into an infinity; NaNs are kept as quiet NaNs here.
static bf16_t f32_to_bf16_host(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof(u));
    if ((u & 0x7fffffffu) > 0x7f800000u) {
        // NaN: keep sign and upper payload, force the quiet bit so it stays NaN.
        return (uint16_t)((u >> 16) | 0x0040u);
    }
    // Add 0x7fff, plus 1 when the kept LSB is odd, so exact halfway cases round to even.
    const uint32_t rounding_bias = 0x7fffu + ((u >> 16) & 1u);
    return (uint16_t)((u + rounding_bias) >> 16);
}

// Widen bf16 bits to fp32 on the host: the bf16 pattern becomes the upper
// 16 bits of an IEEE-754 binary32 value and the low 16 bits are zero.
// NOTE(review): appears unused by the code in this file; kept as the inverse
// of f32_to_bf16_host.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = (uint32_t)h << 16;
    float out;
    memcpy(&out, &bits, sizeof out);
    return out;
}

// Problem sizes for this single-row test.
constexpr int SK = 128;        // tokens per score row (length of S / P)
constexpr int HD = 16;         // head dimension (columns of O, rows of the stored V)
constexpr int BLOCK_MN = 128;  // presumably the MMA M tile; not referenced by the code shown

// -----------------------------------------------------------------------------
// Small PTX wrappers local to this test. They are namespaced so they cannot
// clash with similarly named helpers exported through `using namespace`.
// -----------------------------------------------------------------------------
namespace pv_test {

// Pack two fp32 values into one 32-bit word of two bf16 values (round-to-nearest-even).
// `lo` ends up in bits [15:0] (the lower K index) and `hi` in bits [31:16].
__device__ __forceinline__ uint32_t pack_bf16x2(float lo, float hi) {
    uint32_t packed;
    // cvt.rn.bf16x2.f32 d, a, b places a in the upper half and b in the lower half.
    asm("cvt.rn.bf16x2.f32 %0, %1, %2;" : "=r"(packed) : "f"(hi), "f"(lo));
    return packed;
}

// Initialise a shared-memory mbarrier expecting `count` arrivals per phase and
// make the initialisation visible before other threads / the async proxy use it.
__device__ __forceinline__ void mbar_init(uint32_t mbar, uint32_t count) {
    asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(mbar), "r"(count) : "memory");
    asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
}

// Arrange for one arrival on `mbar` once every tcgen05 async op previously
// issued by this thread (here: the MMA) has completed.
__device__ __forceinline__ void umma_commit(uint32_t mbar) {
    asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                 :: "r"(mbar) : "memory");
}

// Spin until the mbarrier phase with the given parity has completed.
__device__ __forceinline__ void mbar_wait(uint32_t mbar, uint32_t parity) {
    uint32_t done = 0;
    do {
        asm volatile("{\n\t"
                     ".reg .pred p;\n\t"
                     "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                     "selp.u32 %0, 1, 0, p;\n\t"
                     "}"
                     : "=r"(done) : "r"(mbar), "r"(parity) : "memory");
    } while (!done);
}

}  // namespace pv_test

/**
 * test_softmax_pv: single-CTA check of the softmax -> PV tensor-core path.
 *
 * Launch: <<<1, 128, smem>>> with smem >= 16 + 16*16*sizeof(bf16_t) bytes.
 * Uses tcgen05 (TMEM / UMMA) PTX, so it needs a Blackwell sm_100a-class target.
 *
 *  1. Stage V[tokens 0..15, dims 0..15] in SMEM as the UMMA B operand
 *     (N = head dim, K = token).
 *  2. Allocate 256 TMEM columns: P in columns [0, 64), O in columns [128, 144).
 *  3. Lane 0 of warp 0 (TMEM lane 0 = row 0) computes softmax over S row 0,
 *     converts P to bf16 and stores it packed two-per-column into TMEM.
 *  4. Thread 0 issues one TS MMA, O = P[:, 0:16] x V[0:16, :], and commits it
 *     to an mbarrier. Warp 0 waits on that barrier, then reads O row 0.
 *  5. Thread 0 computes an fp32 scalar reference of the same product.
 *
 * @param s_input (SK,) row 0 of S, already scaled by the caller.
 * @param v       (HD, SK) bf16, indexed v[d * SK + token].
 * @param o_mma   (HD,) output: O row 0 from the tensor-core path.
 * @param o_ref   (HD,) output: scalar reference result.
 * @param scale   unused; s_input is expected to be pre-scaled.
 */
__global__ void __launch_bounds__(128)
test_softmax_pv(const float* __restrict__ s_input,
                const bf16_t* __restrict__ v,
                float* __restrict__ o_mma,
                float* __restrict__ o_ref,
                float scale)
{
    static_assert(HD == 16, "V staging below assumes a 16x16 (dims x tokens) K-tile");
    static_assert(SK % 16 == 0, "P packing stores 16 bf16 values per tcgen05.st");
    constexpr int KT = 16;  // tokens per PV K-tile (one bf16 UMMA K step)
    (void)scale;            // S arrives pre-scaled

    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    // Dynamic SMEM layout:
    //   [0, 4)    TMEM base address written by tcgen05.alloc
    //   [8, 16)   mbarrier signalled when the MMA completes
    //   [16, 528) V K-tile, 16 x 16 bf16
    extern __shared__ __align__(16) char sbuf[];
    uint32_t* sTmemBase = reinterpret_cast<uint32_t*>(sbuf);
    uint64_t* sMmaBar = reinterpret_cast<uint64_t*>(sbuf + 8);
    bf16_t* sV = reinterpret_cast<bf16_t*>(sbuf + 16);
    const uint32_t mbar = static_cast<uint32_t>(__cvta_generic_to_shared(sMmaBar));

    // ---- 1. Stage the V K-tile as the B operand: N = head dim d, K = token t.
    // No-swizzle K-major canonical layout: each 8x8 core matrix holds 8 N-rows of
    // 8 K-contiguous elements; K-adjacent core matrices are 128 elements (256 B)
    // apart and N-adjacent ones 64 elements (128 B) apart. The 256 destinations
    // are all distinct, so no zero-fill is needed.
    // NOTE(review): assumes make_umma_desc_kmajor_none(..., 16) encodes
    // LBO = 256 B (K direction) and SBO = 128 B (N direction); confirm against
    // fmha_umma_desc.cuh.
    for (int i = tid; i < HD * KT; i += blockDim.x) {
        const int d = i / KT;  // N index (head dim)
        const int t = i % KT;  // K index (token)
        const int dst = (t / 8) * 128 + (d / 8) * 64 + (d % 8) * 8 + (t % 8);
        sV[dst] = v[d * SK + t];
    }
    // Generic-proxy SMEM writes must be made visible to the tensor core (async proxy).
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    if (tid == 0) pv_test::mbar_init(mbar, 1);
    __syncthreads();

    // ---- 2. TMEM allocation by warp 1: P in columns [0, 64) (bf16x2), O at column 128.
    if (wid == 1) {
        tmem_alloc(__cvta_generic_to_shared(sTmemBase), 256);
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    }
    __syncthreads();
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    const uint32_t tb = *sTmemBase;
    const uint32_t tb_o = tb + 128;

    // ---- 3. Softmax of row 0 and P -> TMEM.
    // With the 32x32b shape, lane i of warp 0 addresses TMEM lane (row) i, so only
    // lane 0 carries data. Rows 1..127 of P are left unwritten: row r of O depends
    // only on row r of P, and only O row 0 is read back.
    if (wid == 0) {
        float p[SK];  // only lane 0 initialises or reads this
        if (lane == 0) {
            float row_max = -INFINITY;
            for (int j = 0; j < SK; j++) { p[j] = s_input[j]; row_max = fmaxf(row_max, p[j]); }
            float row_sum = 0.0f;
            for (int j = 0; j < SK; j++) { p[j] = expf(p[j] - row_max); row_sum += p[j]; }
            for (int j = 0; j < SK; j++) p[j] /= row_sum;
        }

        // A kind::f16 TS MMA reads A from TMEM as packed 16-bit values: each 32-bit
        // column holds K elements (2c, 2c+1). SK bf16 values therefore span SK/2 columns.
        for (int n = 0; n < SK / 16; n++) {
            uint32_t w[8] = {0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u};
            if (lane == 0) {
                for (int c = 0; c < 8; c++)
                    w[c] = pv_test::pack_bf16x2(p[n * 16 + 2 * c], p[n * 16 + 2 * c + 1]);
            }
            __syncwarp();  // tcgen05.st.sync.aligned needs the whole warp converged
            asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
                         :: "r"(tb + n * 8),
                            "r"(w[0]), "r"(w[1]), "r"(w[2]), "r"(w[3]),
                            "r"(w[4]), "r"(w[5]), "r"(w[6]), "r"(w[7])
                         : "memory");
        }
        tmem_fence_store();  // presumably tcgen05.wait::st (project helper)
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    }
    __syncthreads();

    // ---- 4. PV GEMM, one K-tile: O(tb_o) = P[:, 0:16] (TMEM) x V[0:16, :] (SMEM).
    if (tid == 0) {
        // Order the MMA after the barrier that published the P stores.
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        const uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
        const uint32_t idesc = make_idesc(128, 16);
        umma_ts_f16(tb_o, tb, dv, idesc, false);  // last arg presumably "accumulate" (false = overwrite)
        pv_test::umma_commit(mbar);               // arrive on mbar once the MMA has finished
    }

    // ---- 5. Warp 0 waits for MMA completion, then reads O row 0 (TMEM lane 0).
    if (wid == 0) {
        pv_test::mbar_wait(mbar, 0);
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        __syncwarp();
        for (int n = 0; n < HD / 8; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];"
                         : "=f"(tmp[0]), "=f"(tmp[1]), "=f"(tmp[2]), "=f"(tmp[3]),
                           "=f"(tmp[4]), "=f"(tmp[5]), "=f"(tmp[6]), "=f"(tmp[7])
                         : "r"(tb_o + n * 8));
            asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
            if (lane == 0)
                for (int c = 0; c < 8; c++) o_mma[n * 8 + c] = tmp[c];
        }
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    }
    __syncthreads();

    // ---- 6. Scalar fp32 reference: O = P x V over the first KT tokens only.
    // P stays in fp32 here, while the MMA consumed bf16 P, so expect small differences.
    if (tid == 0) {
        float s[SK];
        for (int j = 0; j < SK; j++) s[j] = s_input[j];
        float mx = -INFINITY;
        for (int j = 0; j < SK; j++) mx = fmaxf(mx, s[j]);
        float sm = 0.0f;
        for (int j = 0; j < SK; j++) { s[j] = expf(s[j] - mx); sm += s[j]; }
        for (int j = 0; j < SK; j++) s[j] /= sm;
        for (int d = 0; d < HD; d++) {
            float ov = 0.0f;
            for (int j = 0; j < KT; j++) ov += s[j] * bf16_to_f32(v[d * SK + j]);
            o_ref[d] = ov;
        }
    }

    // ---- 7. Release TMEM from the warp that allocated it. All TMEM reads by
    // warp 0 completed before the barrier above.
    if (wid == 1) {
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        tmem_dealloc(tb, 256);
    }
}

// Exit the test binary with a diagnostic if a CUDA runtime call failed.
static void check_cuda(cudaError_t err, const char* what) {
    if (err != cudaSuccess) {
        printf("CUDA ERROR (%s): %s\n", what, cudaGetErrorString(err));
        exit(1);
    }
}

/**
 * Host driver for test_softmax_pv.
 *
 * Generates a score row S (SK values in [-1, 1)) and V (HD x SK bf16 values in
 * [-0.5, 0.5)) with a fixed seed. It then runs the kernel on one 128-thread CTA
 * and compares the tensor-core result with the kernel's scalar reference.
 *
 * Pass criteria: cosine similarity > 0.9 AND every element satisfies
 * |mma - ref| <= kAtol + kRtol * |ref|. Cosine similarity alone is
 * scale-invariant, so it cannot catch a wrong overall scale of O.
 *
 * @return 0 on pass, 1 on mismatch or on any CUDA error.
 */
int main() {
    printf("=== Softmax → PV (1 K-tile) ===\n");
    // Forwarded to the kernel; S below is generated as if it were already scaled.
    const float SCALE = 1.0f / sqrtf((float)HD);

    std::vector<float> h_s(SK);
    std::vector<bf16_t> h_v(HD * SK);
    std::vector<float> h_o_mma(HD, 0.0f);
    std::vector<float> h_o_ref(HD, 0.0f);

    srand(42);
    // S values (simulated attention scores) first, then V; rand() order fixes the data.
    for (int j = 0; j < SK; j++) h_s[j] = ((float)(rand() % 100) / 100.0f - 0.5f) * 2.0f;
    for (int i = 0; i < HD * SK; i++) h_v[i] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);

    float* d_s = nullptr;
    bf16_t* d_v = nullptr;
    float* d_o_mma = nullptr;
    float* d_o_ref = nullptr;
    check_cuda(cudaMalloc(&d_s, SK * sizeof(float)), "cudaMalloc d_s");
    check_cuda(cudaMalloc(&d_v, HD * SK * sizeof(bf16_t)), "cudaMalloc d_v");
    check_cuda(cudaMalloc(&d_o_mma, HD * sizeof(float)), "cudaMalloc d_o_mma");
    check_cuda(cudaMalloc(&d_o_ref, HD * sizeof(float)), "cudaMalloc d_o_ref");
    check_cuda(cudaMemcpy(d_s, h_s.data(), SK * sizeof(float), cudaMemcpyHostToDevice), "copy S");
    check_cuda(cudaMemcpy(d_v, h_v.data(), HD * SK * sizeof(bf16_t), cudaMemcpyHostToDevice), "copy V");

    // Dynamic SMEM: a 16-byte header in front of one 16x16 bf16 V tile, rounded up to 128 B.
    const int smem = (16 + 16 * 16 * (int)sizeof(bf16_t) + 127) & ~127;
    test_softmax_pv<<<1, 128, smem>>>(d_s, d_v, d_o_mma, d_o_ref, SCALE);
    check_cuda(cudaGetLastError(), "kernel launch");        // bad launch configuration
    check_cuda(cudaDeviceSynchronize(), "kernel execution"); // faults inside the kernel

    check_cuda(cudaMemcpy(h_o_mma.data(), d_o_mma, HD * sizeof(float), cudaMemcpyDeviceToHost), "copy O (mma)");
    check_cuda(cudaMemcpy(h_o_ref.data(), d_o_ref, HD * sizeof(float), cudaMemcpyDeviceToHost), "copy O (ref)");

    printf("O[0..15] MMA: "); for (int d = 0; d < HD; d++) printf("%.6f ", h_o_mma[d]); printf("\n");
    printf("O[0..15] ref: "); for (int d = 0; d < HD; d++) printf("%.6f ", h_o_ref[d]); printf("\n");

    // The MMA consumes bf16 P while the reference keeps fp32 P, so allow bf16-level error.
    const float kAtol = 5e-4f, kRtol = 2e-2f;
    float cos_sim = 0.0f, na = 0.0f, nb = 0.0f, max_abs_err = 0.0f;
    bool within_tol = true;
    for (int d = 0; d < HD; d++) {
        const float a = h_o_mma[d], b = h_o_ref[d];
        cos_sim += a * b; na += a * a; nb += b * b;
        const float err = fabsf(a - b);
        max_abs_err = fmaxf(max_abs_err, err);
        if (!(err <= kAtol + kRtol * fabsf(b))) within_tol = false;  // also rejects NaN
    }
    cos_sim /= (sqrtf(na) * sqrtf(nb) + 1e-10f);
    printf("cosine: %.8f\n", cos_sim);
    printf("max abs err: %.8f\n", max_abs_err);

    const bool passed = cos_sim > 0.9f && within_tol;
    printf("Test %s\n", passed ? "PASSED" : "FAILED");

    cudaFree(d_s); cudaFree(d_v); cudaFree(d_o_mma); cudaFree(d_o_ref);
    return passed ? 0 : 1;
}
