test: PV-only GEMM — isolate PV from full FMHA pipeline
This commit is contained in:
154
tests/unit/test_pv_only.cu
Normal file
154
tests/unit/test_pv_only.cu
Normal file
@@ -0,0 +1,154 @@
|
||||
/**
|
||||
* Test PV GEMM only: feed known P values, load V directly, verify output.
|
||||
* Uses the same 4-warp pattern as test_fmha_gen.cu.
|
||||
* Skips QK and softmax entirely — just PV → O.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda.h>
|
||||
#include <cstdio>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
#ifndef HD_VAL
|
||||
#define HD_VAL 64
|
||||
#endif
|
||||
|
||||
#include "dsv4/kernels/attention/fmha_common.cuh"
|
||||
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
||||
|
||||
using namespace dsv4::kernels::attention;
|
||||
|
||||
static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
|
||||
static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; }
|
||||
|
||||
constexpr int HD = HD_VAL;
|
||||
constexpr int SK = 128;
|
||||
constexpr int BLOCK_MN = 128;
|
||||
constexpr int MMA_K = MMA_K_BF16;
|
||||
constexpr int NKT_PV = SK / MMA_K;
|
||||
constexpr int N_NSUB = HD / 16;
|
||||
constexpr int TILE_SZ = BLOCK_MN * MMA_K;
|
||||
constexpr int TMEM_N = (HD <= 128) ? 128 : 256;
|
||||
constexpr int V_SUB_SZ = 256; // (16,16) canonical
|
||||
constexpr int CORES_MN = 16;
|
||||
constexpr int CORES_MN_V = 2;
|
||||
|
||||
__global__ void __launch_bounds__(128)
|
||||
test_pv_only_kernel(
|
||||
float* __restrict__ out, // (HD,) — PV output
|
||||
const bf16_t* __restrict__ v, // (HD, SK)
|
||||
const float* __restrict__ p_input // (SK,) — pre-computed P values
|
||||
) {
|
||||
const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;
|
||||
|
||||
extern __shared__ __align__(128) char sbuf[];
|
||||
size_t off = 0;
|
||||
uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off = 4;
|
||||
off = (off + 15) & ~(size_t)15;
|
||||
bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sV = (bf16_t*)(sbuf + off); off += V_SUB_SZ * sizeof(bf16_t);
|
||||
float* s_p_vals = (float*)(sbuf + off); off += SK * sizeof(float);
|
||||
|
||||
// Load P values into SMEM
|
||||
for (int j = tid; j < SK; j += 128) s_p_vals[j] = p_input[j];
|
||||
__syncthreads();
|
||||
|
||||
// TMEM alloc
|
||||
if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N);
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// PV GEMM — EXACT pattern from test_fmha_gen.cu
|
||||
{
|
||||
uint32_t idesc_pv16 = make_idesc(BLOCK_MN, 16);
|
||||
uint64_t dp = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sPk), BLOCK_MN);
|
||||
|
||||
for (int n = 0; n < N_NSUB; n++) {
|
||||
int d_base = n * 16;
|
||||
for (int kt = 0; kt < NKT_PV; kt++) {
|
||||
for (int i = tid; i < TILE_SZ; i += 128) sPk[i] = 0;
|
||||
if (tid < 16) {
|
||||
int c = tid;
|
||||
int ck = c / 8, lc = c % 8;
|
||||
sPk[ck * CORES_MN * 64 + 0 * 64 + 0 * 8 + lc] = f32_to_bf16(s_p_vals[kt * MMA_K + c]);
|
||||
}
|
||||
for (int i = tid; i < V_SUB_SZ; i += 128) sV[i] = 0;
|
||||
for (int dd = tid; dd < 16; dd += 128) {
|
||||
for (int lr = 0; lr < MMA_K; lr++) {
|
||||
int r = kt * MMA_K + lr;
|
||||
int g_mn = dd / 8, g_k = lr / 8, llr = dd % 8, lc = lr % 8;
|
||||
sV[g_k * CORES_MN_V * 64 + g_mn * 64 + llr * 8 + lc] = v[(d_base + dd) * SK + r];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
|
||||
if (tid == 0) umma_ss_f16(tb + n * 16, dp, dv, idesc_pv16, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Read from TMEM
|
||||
if (wid == 0) {
|
||||
float o_vals[HD];
|
||||
for (int n = 0; n < HD / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n * 8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c = 0; c < 8; c++) o_vals[n * 8 + c] = tmp[c];
|
||||
}
|
||||
if (lane == 0) for (int d = 0; d < HD; d++) out[d] = o_vals[d];
|
||||
}
|
||||
|
||||
if (wid == 0) tmem_dealloc(tb, TMEM_N);
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("PV-Only Test (HD=%d, SK=%d)\n", HD);
|
||||
|
||||
bf16_t* h_v = (bf16_t*)malloc(HD * SK * sizeof(bf16_t));
|
||||
float* h_p = (float*)malloc(SK * sizeof(float));
|
||||
srand(42);
|
||||
for (int i = 0; i < HD * SK; i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);
|
||||
// Simple P: uniform (all 1/SK)
|
||||
for (int j = 0; j < SK; j++) h_p[j] = 1.0f / SK;
|
||||
|
||||
bf16_t *d_v; float *d_p, *d_out;
|
||||
cudaMalloc(&d_v, HD * SK * sizeof(bf16_t));
|
||||
cudaMalloc(&d_p, SK * sizeof(float));
|
||||
cudaMalloc(&d_out, HD * sizeof(float));
|
||||
cudaMemcpy(d_v, h_v, HD * SK * sizeof(bf16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_p, h_p, SK * sizeof(float), cudaMemcpyHostToDevice);
|
||||
|
||||
int smem = 4 + 16 + TILE_SZ + V_SUB_SZ*2 + SK*4 + 256;
|
||||
cudaFuncSetAttribute(test_pv_only_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
|
||||
test_pv_only_kernel<<<1, 128, smem>>>(d_out, d_v, d_p);
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
|
||||
|
||||
float* h_out = (float*)malloc(HD * sizeof(float));
|
||||
cudaMemcpy(h_out, d_out, HD * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
|
||||
// Reference: O = P @ V where P = 1/SK, so O[d] = (1/SK) * sum_j V[d,j]
|
||||
int fail = 0; float max_rel = 0;
|
||||
for (int d = 0; d < HD; d++) {
|
||||
float ref = 0.0f;
|
||||
for (int j = 0; j < SK; j++) ref += bf16_to_f32_host(h_v[d * SK + j]) / SK;
|
||||
float got = h_out[d];
|
||||
float rel = fabsf(ref) > 1e-4f ? fabsf(got - ref) / fabsf(ref) : fabsf(got - ref);
|
||||
if (rel > max_rel) max_rel = rel;
|
||||
if (rel > 0.01f && fail < 5) printf(" d=%d: ref=%.6f got=%.6f\n", d, ref, got);
|
||||
if (rel > 0.01f) fail++;
|
||||
}
|
||||
printf("Max rel err: %.8f, failures: %d\n", max_rel, fail);
|
||||
printf("%s\n", fail == 0 ? "PASSED" : "FAILED");
|
||||
return fail == 0 ? 0 : 1;
|
||||
}
|
||||
Reference in New Issue
Block a user