diff --git a/tests/unit/test_p5_tma_multitile.cu b/tests/unit/test_p5_tma_multitile.cu index 32eedfc3..af9c22a0 100644 --- a/tests/unit/test_p5_tma_multitile.cu +++ b/tests/unit/test_p5_tma_multitile.cu @@ -1,47 +1,40 @@ /** * P5: Test multi-tile TMA FMHA kernel with proper alignment. */ +#include +#include #include "dsv4/kernels/attention/fmha_common.cuh" #include "dsv4/kernels/attention/fmha_umma_desc.cuh" #include "dsv4/kernels/attention/fmha_tma.cuh" #include "dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh" - #include #include #include using namespace dsv4::kernels::attention; -static float hbf16_to_f32(uint16_t h) { - uint32_t u = ((uint32_t)h) << 16; - float f; memcpy(&f, &u, 4); return f; -} -static uint16_t hf32_to_bf16(float f) { - uint32_t u; memcpy(&u, &f, 4); return (uint16_t)(u >> 16); -} +static float hbf16_to_f32(uint16_t h) { uint32_t u = ((uint32_t)h) << 16; float f; memcpy(&f, &u, 4); return f; } +static uint16_t hf32_to_bf16(float f) { uint32_t u; memcpy(&u, &f, 4); return (uint16_t)(u >> 16); } int main() { constexpr int HD = 64; constexpr int SK = 256; const float SCALE = 1.0f / sqrtf((float)HD); - // Allocate 128B-aligned GPU memory - bf16_t *d_q, *d_k, *d_v, *d_o; - float *d_lse; - cudaMalloc(&d_q, HD * 2 + 128); - cudaMalloc(&d_k, SK * HD * 2 + 128); - cudaMalloc(&d_v, HD * SK * 2 + 128); - cudaMalloc(&d_o, HD * 2 + 128); - cudaMalloc(&d_lse, 4 + 128); + bf16_t *d_q_raw, *d_k_raw, *d_v_raw, *d_o_raw; + float *d_lse_raw; + cudaMalloc(&d_q_raw, HD * 2 + 128); + cudaMalloc(&d_k_raw, SK * HD * 2 + 128); + cudaMalloc(&d_v_raw, HD * SK * 2 + 128); + cudaMalloc(&d_o_raw, HD * 2 + 128); + cudaMalloc(&d_lse_raw, 4 + 128); - // Align pointers - d_q = (bf16_t*)(((uintptr_t)d_q + 127) & ~(uintptr_t)127); - d_k = (bf16_t*)(((uintptr_t)d_k + 127) & ~(uintptr_t)127); - d_v = (bf16_t*)(((uintptr_t)d_v + 127) & ~(uintptr_t)127); - d_o = (bf16_t*)(((uintptr_t)d_o + 127) & ~(uintptr_t)127); - d_lse = (float*)(((uintptr_t)d_lse + 127) & ~(uintptr_t)127); + bf16_t *d_q = (bf16_t*)(((uintptr_t)d_q_raw + 127) & ~(uintptr_t)127); + bf16_t *d_k = (bf16_t*)(((uintptr_t)d_k_raw + 127) & ~(uintptr_t)127); + bf16_t *d_v = (bf16_t*)(((uintptr_t)d_v_raw + 127) & ~(uintptr_t)127); + bf16_t *d_o = (bf16_t*)(((uintptr_t)d_o_raw + 127) & ~(uintptr_t)127); + float *d_lse = (float*)(((uintptr_t)d_lse_raw + 127) & ~(uintptr_t)127); - // Fill with random data srand(42); bf16_t h_q[HD], h_k[SK * HD], h_v[HD * SK]; for (int d = 0; d < HD; d++) h_q[d] = hf32_to_bf16((float)(rand() % 100) / 100.0f); @@ -51,17 +44,10 @@ int main() { cudaMemcpy(d_q, h_q, HD * 2, cudaMemcpyHostToDevice); cudaMemcpy(d_k, h_k, SK * HD * 2, cudaMemcpyHostToDevice); cudaMemcpy(d_v, h_v, HD * SK * 2, cudaMemcpyHostToDevice); - cudaMemset(d_o, 0, HD * 2); - cudaMemset(d_lse, 0, 4); - // Create TMA descriptors CUtensorMap h_tma_k, h_tma_v; - if (!create_tma_desc_2d_bf16(&h_tma_k, d_k, SK, HD, 128, 16)) { - printf("K TMA desc creation FAILED\n"); return 1; - } - if (!create_tma_desc_2d_bf16(&h_tma_v, d_v, HD, SK, 16, 16)) { - printf("V TMA desc creation FAILED\n"); return 1; - } + if (!create_tma_desc_2d_bf16(&h_tma_k, d_k, SK, HD, 128, 16)) { printf("K TMA FAILED\n"); return 1; } + if (!create_tma_desc_2d_bf16(&h_tma_v, d_v, HD, SK, 16, 16)) { printf("V TMA FAILED\n"); return 1; } CUtensorMap *d_tma_k, *d_tma_v; cudaMalloc(&d_tma_k, sizeof(CUtensorMap)); @@ -69,79 +55,67 @@ int main() { cudaMemcpy(d_tma_k, &h_tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice); cudaMemcpy(d_tma_v, &h_tma_v, sizeof(CUtensorMap), cudaMemcpyHostToDevice); - // Launch multi-tile kernel FmhaTmaMultiRowMultiTileParams params; - params.q = d_q; - params.tma_k = d_tma_k; - params.tma_v = d_tma_v; - params.o = d_o; - params.lse = d_lse; - params.s_k = SK; - params.T = 1; - params.n_h = 1; - params.scale = SCALE; - params.q_head_stride = 0; - params.q_batch_stride = 0; - params.o_head_stride = 0; - params.o_batch_stride = 0; - params.lse_head_stride = 0; - params.lse_batch_stride = 0; + params.q = d_q; params.tma_k = d_tma_k; params.tma_v = d_tma_v; + params.o = d_o; params.lse = d_lse; + params.s_k = SK; params.T = 1; params.n_h = 1; params.scale = SCALE; + params.q_head_stride = 0; params.q_batch_stride = 0; + params.o_head_stride = 0; params.o_batch_stride = 0; + params.lse_head_stride = 0; params.lse_batch_stride = 0; + + // Compute SMEM (match kernel layout) + constexpr int HD_CHUNK = 256; + constexpr int TILE_SZ = 128 * 16; + constexpr int V_SUB_SZ = 16 * 16; + int hc = (HD <= 256) ? HD : HD_CHUNK; + size_t off = 0; + off += 4; off = (off+127)&~(size_t)127; // tmembase + off += 16; off = (off+127)&~(size_t)127; // mbar + off += TILE_SZ*2; off = (off+127)&~(size_t)127; // tmabuf + off += TILE_SZ*2; off = (off+127)&~(size_t)127; // q0 + off += TILE_SZ*2; off = (off+127)&~(size_t)127; // k0 + off += TILE_SZ*2; off = (off+127)&~(size_t)127; // pk + off += V_SUB_SZ*2; off = (off+127)&~(size_t)127; // v + off += 128*hc*4; off += 128*4; off += 128*4; off += 128*4; off += 128*4; off += 256; + int smem = (int)((off + 127) & ~(size_t)127); - int smem = 4 + 16 + 128*16*2 + 128*16*2 + 128*16*2 + 128*16*2 + 16*16*2 - + 128*256*4 + 128*4 + 128*4 + 128*4 + 128*4 + 256 + 127; - smem &= ~127; dim3 grid(1, 1, 1); dim3 block(NTHREADS); - cudaFuncSetAttribute(fmha_6warp_tma_multirow_multitile_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); fmha_6warp_tma_multirow_multitile_kernel<<>>(params); cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - printf("Kernel FAILED: %s\n", cudaGetErrorString(err)); - return 1; - } + if (err != cudaSuccess) { printf("Kernel FAILED: %s\n", cudaGetErrorString(err)); return 1; } - // Compare with CPU reference - float h_o_ref[HD], h_lse_ref; - float row_max = -INFINITY; + // CPU reference + float h_o_ref[HD], row_max = -INFINITY; for (int j = 0; j < SK; j++) { - float dot = 0.0f; - for (int d = 0; d < HD; d++) dot += hbf16_to_f32(h_q[d]) * hbf16_to_f32(h_k[j * HD + d]); + float dot = 0; + for (int d = 0; d < HD; d++) dot += hbf16_to_f32(h_q[d]) * hbf16_to_f32(h_k[j*HD+d]); dot *= SCALE; if (dot > row_max) row_max = dot; } - float row_sum = 0.0f; - for (int d = 0; d < HD; d++) h_o_ref[d] = 0.0f; + float row_sum = 0; + for (int d = 0; d < HD; d++) h_o_ref[d] = 0; for (int j = 0; j < SK; j++) { - float dot = 0.0f; - for (int d = 0; d < HD; d++) dot += hbf16_to_f32(h_q[d]) * hbf16_to_f32(h_k[j * HD + d]); - dot *= SCALE; - float p = expf(dot - row_max); + float dot = 0; + for (int d = 0; d < HD; d++) dot += hbf16_to_f32(h_q[d]) * hbf16_to_f32(h_k[j*HD+d]); + float p = expf(dot * SCALE - row_max); row_sum += p; - for (int d = 0; d < HD; d++) h_o_ref[d] += p * hbf16_to_f32(h_v[d * SK + j]); + for (int d = 0; d < HD; d++) h_o_ref[d] += p * hbf16_to_f32(h_v[d*SK+j]); } for (int d = 0; d < HD; d++) h_o_ref[d] /= row_sum; - h_lse_ref = logf(row_sum) + row_max; bf16_t h_o[HD]; - float h_lse; cudaMemcpy(h_o, d_o, HD * 2, cudaMemcpyDeviceToHost); - cudaMemcpy(&h_lse, d_lse, 4, cudaMemcpyDeviceToHost); float cos = 0, na = 0, nb = 0; - for (int d = 0; d < HD; d++) { - float a = h_o_ref[d], b = hbf16_to_f32(h_o[d]); - cos += a * b; na += a * a; nb += b * b; - } + for (int d = 0; d < HD; d++) { float a = h_o_ref[d], b = hbf16_to_f32(h_o[d]); cos += a*b; na += a*a; nb += b*b; } cos /= sqrtf(na * nb + 1e-30f); - printf("Multi-tile TMA FMHA (HD=%d, SK=%d):\n", HD, SK); - printf(" LSE: kernel=%.4f ref=%.4f\n", h_lse, h_lse_ref); - printf(" Cosine: %.6f\n", cos); - printf(" %s\n", cos >= 0.999990 ? "PASS" : "FAIL"); - + printf("Multi-tile TMA FMHA (HD=%d, SK=%d): cos=%.6f %s\n", HD, SK, cos, cos >= 0.999990 ? "PASS" : "FAIL"); + cudaFree(d_q_raw); cudaFree(d_k_raw); cudaFree(d_v_raw); cudaFree(d_o_raw); cudaFree(d_lse_raw); cudaFree(d_tma_k); cudaFree(d_tma_v); return (cos >= 0.999990) ? 0 : 1; }