P5: fix TMA multitile test (include cuda.h first, proper SMEM calc)
This commit is contained in:
@@ -1,47 +1,40 @@
|
||||
/**
|
||||
* P5: Test multi-tile TMA FMHA kernel with proper alignment.
|
||||
*/
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include "dsv4/kernels/attention/fmha_common.cuh"
|
||||
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
||||
#include "dsv4/kernels/attention/fmha_tma.cuh"
|
||||
#include "dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
|
||||
using namespace dsv4::kernels::attention;
|
||||
|
||||
// Host helper: turns a 16-bit pattern into a float.
// The 16 input bits become the upper half of a 32-bit word (lower half zero),
// and that word's bytes are copied into a float with memcpy.
// NOTE(review): the name suggests `h` holds a bfloat16 bit pattern — confirm against callers.
static float hbf16_to_f32(uint16_t h) {
|
||||
// Move the 16 bits into the high half of a 32-bit word.
uint32_t u = ((uint32_t)h) << 16;
|
||||
// Reinterpret the 4 bytes of `u` as a float without pointer casts.
float f; memcpy(&f, &u, 4); return f;
|
||||
}
|
||||
// Host helper: returns the upper 16 bits of a float's 32-bit representation.
// The lower 16 bits are discarded by the shift (plain truncation; no rounding is applied).
// NOTE(review): the name suggests the result is used as a bfloat16 bit pattern — confirm against callers.
static uint16_t hf32_to_bf16(float f) {
|
||||
// Copy the float's 4 bytes into an integer, then keep only the high half.
uint32_t u; memcpy(&u, &f, 4); return (uint16_t)(u >> 16);
|
||||
}
|
||||
// Host helper: builds a float whose upper 16 bits are `h` and whose lower
// 16 bits are zero, by copying the widened bit pattern byte-for-byte.
static float hbf16_to_f32(uint16_t h) {
    // Place the 16 input bits in the high half of a 32-bit word.
    const uint32_t widened = static_cast<uint32_t>(h) << 16;
    float result;
    // Byte copy keeps the bit pattern intact (no numeric conversion).
    std::memcpy(&result, &widened, sizeof(result));
    return result;
}
|
||||
// Host helper: returns the upper 16 bits of a float's bit pattern.
// The lower 16 bits are simply shifted out (truncation, no rounding).
static uint16_t hf32_to_bf16(float f) {
    uint32_t bits;
    // Byte copy exposes the float's raw 32-bit representation.
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
}
|
||||
|
||||
int main() {
|
||||
constexpr int HD = 64;
|
||||
constexpr int SK = 256;
|
||||
const float SCALE = 1.0f / sqrtf((float)HD);
|
||||
|
||||
// Allocate 128B-aligned GPU memory
|
||||
bf16_t *d_q, *d_k, *d_v, *d_o;
|
||||
float *d_lse;
|
||||
cudaMalloc(&d_q, HD * 2 + 128);
|
||||
cudaMalloc(&d_k, SK * HD * 2 + 128);
|
||||
cudaMalloc(&d_v, HD * SK * 2 + 128);
|
||||
cudaMalloc(&d_o, HD * 2 + 128);
|
||||
cudaMalloc(&d_lse, 4 + 128);
|
||||
bf16_t *d_q_raw, *d_k_raw, *d_v_raw, *d_o_raw;
|
||||
float *d_lse_raw;
|
||||
cudaMalloc(&d_q_raw, HD * 2 + 128);
|
||||
cudaMalloc(&d_k_raw, SK * HD * 2 + 128);
|
||||
cudaMalloc(&d_v_raw, HD * SK * 2 + 128);
|
||||
cudaMalloc(&d_o_raw, HD * 2 + 128);
|
||||
cudaMalloc(&d_lse_raw, 4 + 128);
|
||||
|
||||
// Align pointers
|
||||
d_q = (bf16_t*)(((uintptr_t)d_q + 127) & ~(uintptr_t)127);
|
||||
d_k = (bf16_t*)(((uintptr_t)d_k + 127) & ~(uintptr_t)127);
|
||||
d_v = (bf16_t*)(((uintptr_t)d_v + 127) & ~(uintptr_t)127);
|
||||
d_o = (bf16_t*)(((uintptr_t)d_o + 127) & ~(uintptr_t)127);
|
||||
d_lse = (float*)(((uintptr_t)d_lse + 127) & ~(uintptr_t)127);
|
||||
bf16_t *d_q = (bf16_t*)(((uintptr_t)d_q_raw + 127) & ~(uintptr_t)127);
|
||||
bf16_t *d_k = (bf16_t*)(((uintptr_t)d_k_raw + 127) & ~(uintptr_t)127);
|
||||
bf16_t *d_v = (bf16_t*)(((uintptr_t)d_v_raw + 127) & ~(uintptr_t)127);
|
||||
bf16_t *d_o = (bf16_t*)(((uintptr_t)d_o_raw + 127) & ~(uintptr_t)127);
|
||||
float *d_lse = (float*)(((uintptr_t)d_lse_raw + 127) & ~(uintptr_t)127);
|
||||
|
||||
// Fill with random data
|
||||
srand(42);
|
||||
bf16_t h_q[HD], h_k[SK * HD], h_v[HD * SK];
|
||||
for (int d = 0; d < HD; d++) h_q[d] = hf32_to_bf16((float)(rand() % 100) / 100.0f);
|
||||
@@ -51,17 +44,10 @@ int main() {
|
||||
cudaMemcpy(d_q, h_q, HD * 2, cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_k, h_k, SK * HD * 2, cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_v, h_v, HD * SK * 2, cudaMemcpyHostToDevice);
|
||||
cudaMemset(d_o, 0, HD * 2);
|
||||
cudaMemset(d_lse, 0, 4);
|
||||
|
||||
// Create TMA descriptors
|
||||
CUtensorMap h_tma_k, h_tma_v;
|
||||
if (!create_tma_desc_2d_bf16(&h_tma_k, d_k, SK, HD, 128, 16)) {
|
||||
printf("K TMA desc creation FAILED\n"); return 1;
|
||||
}
|
||||
if (!create_tma_desc_2d_bf16(&h_tma_v, d_v, HD, SK, 16, 16)) {
|
||||
printf("V TMA desc creation FAILED\n"); return 1;
|
||||
}
|
||||
if (!create_tma_desc_2d_bf16(&h_tma_k, d_k, SK, HD, 128, 16)) { printf("K TMA FAILED\n"); return 1; }
|
||||
if (!create_tma_desc_2d_bf16(&h_tma_v, d_v, HD, SK, 16, 16)) { printf("V TMA FAILED\n"); return 1; }
|
||||
|
||||
CUtensorMap *d_tma_k, *d_tma_v;
|
||||
cudaMalloc(&d_tma_k, sizeof(CUtensorMap));
|
||||
@@ -69,79 +55,67 @@ int main() {
|
||||
cudaMemcpy(d_tma_k, &h_tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(d_tma_v, &h_tma_v, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
|
||||
// Launch multi-tile kernel
|
||||
FmhaTmaMultiRowMultiTileParams params;
|
||||
params.q = d_q;
|
||||
params.tma_k = d_tma_k;
|
||||
params.tma_v = d_tma_v;
|
||||
params.o = d_o;
|
||||
params.lse = d_lse;
|
||||
params.s_k = SK;
|
||||
params.T = 1;
|
||||
params.n_h = 1;
|
||||
params.scale = SCALE;
|
||||
params.q_head_stride = 0;
|
||||
params.q_batch_stride = 0;
|
||||
params.o_head_stride = 0;
|
||||
params.o_batch_stride = 0;
|
||||
params.lse_head_stride = 0;
|
||||
params.lse_batch_stride = 0;
|
||||
params.q = d_q; params.tma_k = d_tma_k; params.tma_v = d_tma_v;
|
||||
params.o = d_o; params.lse = d_lse;
|
||||
params.s_k = SK; params.T = 1; params.n_h = 1; params.scale = SCALE;
|
||||
params.q_head_stride = 0; params.q_batch_stride = 0;
|
||||
params.o_head_stride = 0; params.o_batch_stride = 0;
|
||||
params.lse_head_stride = 0; params.lse_batch_stride = 0;
|
||||
|
||||
// Compute SMEM (match kernel layout)
|
||||
constexpr int HD_CHUNK = 256;
|
||||
constexpr int TILE_SZ = 128 * 16;
|
||||
constexpr int V_SUB_SZ = 16 * 16;
|
||||
int hc = (HD <= 256) ? HD : HD_CHUNK;
|
||||
size_t off = 0;
|
||||
off += 4; off = (off+127)&~(size_t)127; // tmembase
|
||||
off += 16; off = (off+127)&~(size_t)127; // mbar
|
||||
off += TILE_SZ*2; off = (off+127)&~(size_t)127; // tmabuf
|
||||
off += TILE_SZ*2; off = (off+127)&~(size_t)127; // q0
|
||||
off += TILE_SZ*2; off = (off+127)&~(size_t)127; // k0
|
||||
off += TILE_SZ*2; off = (off+127)&~(size_t)127; // pk
|
||||
off += V_SUB_SZ*2; off = (off+127)&~(size_t)127; // v
|
||||
off += 128*hc*4; off += 128*4; off += 128*4; off += 128*4; off += 128*4; off += 256;
|
||||
int smem = (int)((off + 127) & ~(size_t)127);
|
||||
|
||||
int smem = 4 + 16 + 128*16*2 + 128*16*2 + 128*16*2 + 128*16*2 + 16*16*2
|
||||
+ 128*256*4 + 128*4 + 128*4 + 128*4 + 128*4 + 256 + 127;
|
||||
smem &= ~127;
|
||||
dim3 grid(1, 1, 1);
|
||||
dim3 block(NTHREADS);
|
||||
|
||||
cudaFuncSetAttribute(fmha_6warp_tma_multirow_multitile_kernel<HD>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
|
||||
|
||||
fmha_6warp_tma_multirow_multitile_kernel<HD><<<grid, block, smem>>>(params);
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("Kernel FAILED: %s\n", cudaGetErrorString(err));
|
||||
return 1;
|
||||
}
|
||||
if (err != cudaSuccess) { printf("Kernel FAILED: %s\n", cudaGetErrorString(err)); return 1; }
|
||||
|
||||
// Compare with CPU reference
|
||||
float h_o_ref[HD], h_lse_ref;
|
||||
float row_max = -INFINITY;
|
||||
// CPU reference
|
||||
float h_o_ref[HD], row_max = -INFINITY;
|
||||
for (int j = 0; j < SK; j++) {
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < HD; d++) dot += hbf16_to_f32(h_q[d]) * hbf16_to_f32(h_k[j * HD + d]);
|
||||
float dot = 0;
|
||||
for (int d = 0; d < HD; d++) dot += hbf16_to_f32(h_q[d]) * hbf16_to_f32(h_k[j*HD+d]);
|
||||
dot *= SCALE;
|
||||
if (dot > row_max) row_max = dot;
|
||||
}
|
||||
float row_sum = 0.0f;
|
||||
for (int d = 0; d < HD; d++) h_o_ref[d] = 0.0f;
|
||||
float row_sum = 0;
|
||||
for (int d = 0; d < HD; d++) h_o_ref[d] = 0;
|
||||
for (int j = 0; j < SK; j++) {
|
||||
float dot = 0.0f;
|
||||
for (int d = 0; d < HD; d++) dot += hbf16_to_f32(h_q[d]) * hbf16_to_f32(h_k[j * HD + d]);
|
||||
dot *= SCALE;
|
||||
float p = expf(dot - row_max);
|
||||
float dot = 0;
|
||||
for (int d = 0; d < HD; d++) dot += hbf16_to_f32(h_q[d]) * hbf16_to_f32(h_k[j*HD+d]);
|
||||
float p = expf(dot * SCALE - row_max);
|
||||
row_sum += p;
|
||||
for (int d = 0; d < HD; d++) h_o_ref[d] += p * hbf16_to_f32(h_v[d * SK + j]);
|
||||
for (int d = 0; d < HD; d++) h_o_ref[d] += p * hbf16_to_f32(h_v[d*SK+j]);
|
||||
}
|
||||
for (int d = 0; d < HD; d++) h_o_ref[d] /= row_sum;
|
||||
h_lse_ref = logf(row_sum) + row_max;
|
||||
|
||||
bf16_t h_o[HD];
|
||||
float h_lse;
|
||||
cudaMemcpy(h_o, d_o, HD * 2, cudaMemcpyDeviceToHost);
|
||||
cudaMemcpy(&h_lse, d_lse, 4, cudaMemcpyDeviceToHost);
|
||||
|
||||
float cos = 0, na = 0, nb = 0;
|
||||
for (int d = 0; d < HD; d++) {
|
||||
float a = h_o_ref[d], b = hbf16_to_f32(h_o[d]);
|
||||
cos += a * b; na += a * a; nb += b * b;
|
||||
}
|
||||
for (int d = 0; d < HD; d++) { float a = h_o_ref[d], b = hbf16_to_f32(h_o[d]); cos += a*b; na += a*a; nb += b*b; }
|
||||
cos /= sqrtf(na * nb + 1e-30f);
|
||||
|
||||
printf("Multi-tile TMA FMHA (HD=%d, SK=%d):\n", HD, SK);
|
||||
printf(" LSE: kernel=%.4f ref=%.4f\n", h_lse, h_lse_ref);
|
||||
printf(" Cosine: %.6f\n", cos);
|
||||
printf(" %s\n", cos >= 0.999990 ? "PASS" : "FAIL");
|
||||
|
||||
printf("Multi-tile TMA FMHA (HD=%d, SK=%d): cos=%.6f %s\n", HD, SK, cos, cos >= 0.999990 ? "PASS" : "FAIL");
|
||||
cudaFree(d_q_raw); cudaFree(d_k_raw); cudaFree(d_v_raw); cudaFree(d_o_raw); cudaFree(d_lse_raw);
|
||||
cudaFree(d_tma_k); cudaFree(d_tma_v);
|
||||
return (cos >= 0.999990) ? 0 : 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user