From 754c6a692c4380af8f53fc8fcbb53bdc9c358653 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 29 May 2026 19:44:58 +0000 Subject: [PATCH] feat: per-head TMA descriptors for multi-head FMHA --- .../attention/fmha_6warp_tma_multirow.cuh | 8 ++++---- tests/unit/test_fmha_6warp_tma_multirow.cu | 17 ++++++++++------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh b/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh index 2ad3a8b2..8e0aae58 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma_multirow.cuh @@ -34,11 +34,11 @@ namespace dsv4::kernels::attention { struct FmhaTmaMultiRowParams { const bf16_t* __restrict__ q; - CUtensorMap* __restrict__ tma_k; // K: (s_k, HD) with tile (128, 16) + CUtensorMap* __restrict__ tma_k; // Array of [n_h] TMA descriptors for K const bf16_t* __restrict__ v; // V: direct GMEM (HD, s_k) bf16_t* __restrict__ o; float* __restrict__ lse; - int s_k, T; + int s_k, T, n_h; float scale; int head_dim; int q_head_stride, q_batch_stride; @@ -111,7 +111,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar); int phase = 0; - // Row assignment + CUtensorMap* __restrict__ my_tma_k = params.tma_k + head_idx; const bool my_warp_active = (T <= 32) ? (wid == 0) : is_softmax_warp; const int my_row = my_warp_active ? (wid * 32 + lane) : 0; const bool my_row_active = my_warp_active && (my_row < T); @@ -136,7 +136,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) { // Load K: TMA async if (is_load_warp && lane == 0) { - tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)params.tma_k, + tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_k, mbar_addr, kt * MMA_K_BF16, 0); tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES); } diff --git a/tests/unit/test_fmha_6warp_tma_multirow.cu b/tests/unit/test_fmha_6warp_tma_multirow.cu index aa3035b9..bc589ab0 100644 --- a/tests/unit/test_fmha_6warp_tma_multirow.cu +++ b/tests/unit/test_fmha_6warp_tma_multirow.cu @@ -100,16 +100,19 @@ static int test_single(int T, int n_h = 1, int batch = 1) { cudaMemset(d_o, 0, total_heads*MAX_T*HD*sizeof(bf16_t)); cudaMemset(d_lse, 0, total_heads*MAX_T*sizeof(float)); - // TMA descriptor for K: (SK, HD) — same for all heads (for now) - CUtensorMap tma_k; CUtensorMap* d_tma_k; - create_tma_desc_2d_bf16(&tma_k, d_k, SK, HD, 128, 16); - cudaMalloc(&d_tma_k, sizeof(CUtensorMap)); - cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + // TMA descriptors for K: one per head + CUtensorMap* tma_k_arr = (CUtensorMap*)malloc(total_heads * sizeof(CUtensorMap)); + CUtensorMap* d_tma_k; + cudaMalloc(&d_tma_k, total_heads * sizeof(CUtensorMap)); + for (int h = 0; h < total_heads; h++) { + create_tma_desc_2d_bf16(&tma_k_arr[h], d_k + h*SK*HD, SK, HD, 128, 16); + } + cudaMemcpy(d_tma_k, tma_k_arr, total_heads * sizeof(CUtensorMap), cudaMemcpyHostToDevice); FmhaTmaMultiRowParams params; params.q = d_q; params.tma_k = d_tma_k; params.v = d_v; params.o = d_o; params.lse = d_lse; - params.s_k = SK; params.T = T; params.scale = SCALE; + params.s_k = SK; params.T = T; params.scale = SCALE; params.n_h = n_h; params.head_dim = HD; params.q_head_stride = MAX_T*HD; params.q_batch_stride = n_h*MAX_T*HD; params.k_head_stride = SK*HD; params.k_batch_stride = n_h*SK*HD; @@ -157,7 +160,7 @@ static int test_single(int T, int n_h = 1, int batch = 1) { printf(" min_cos=%.8f %s\n", min_cos, min_cos>0.999f?"PASS":"FAIL"); cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); cudaFree(d_tma_k); - free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); + free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); free(tma_k_arr); return min_cos > 0.999f ? 0 : 1; }