feat: per-head TMA descriptors for multi-head FMHA

This commit is contained in:
2026-05-29 19:44:58 +00:00
parent 9eb193458e
commit 754c6a692c
2 changed files with 14 additions and 11 deletions

View File

@@ -34,11 +34,11 @@ namespace dsv4::kernels::attention {
struct FmhaTmaMultiRowParams {
const bf16_t* __restrict__ q;
CUtensorMap* __restrict__ tma_k; // K: (s_k, HD) with tile (128, 16)
CUtensorMap* __restrict__ tma_k; // Array of [n_h] TMA descriptors for K
const bf16_t* __restrict__ v; // V: direct GMEM (HD, s_k)
bf16_t* __restrict__ o;
float* __restrict__ lse;
int s_k, T;
int s_k, T, n_h;
float scale;
int head_dim;
int q_head_stride, q_batch_stride;
@@ -111,7 +111,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
int phase = 0;
// Row assignment
CUtensorMap* __restrict__ my_tma_k = params.tma_k + head_idx;
const bool my_warp_active = (T <= 32) ? (wid == 0) : is_softmax_warp;
const int my_row = my_warp_active ? (wid * 32 + lane) : 0;
const bool my_row_active = my_warp_active && (my_row < T);
@@ -136,7 +136,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
// Load K: TMA async
if (is_load_warp && lane == 0) {
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)params.tma_k,
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_k,
mbar_addr, kt * MMA_K_BF16, 0);
tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
}

View File

@@ -100,16 +100,19 @@ static int test_single(int T, int n_h = 1, int batch = 1) {
cudaMemset(d_o, 0, total_heads*MAX_T*HD*sizeof(bf16_t));
cudaMemset(d_lse, 0, total_heads*MAX_T*sizeof(float));
// TMA descriptor for K: (SK, HD) — same for all heads (for now)
CUtensorMap tma_k; CUtensorMap* d_tma_k;
create_tma_desc_2d_bf16(&tma_k, d_k, SK, HD, 128, 16);
cudaMalloc(&d_tma_k, sizeof(CUtensorMap));
cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
// TMA descriptors for K: one per head
CUtensorMap* tma_k_arr = (CUtensorMap*)malloc(total_heads * sizeof(CUtensorMap));
CUtensorMap* d_tma_k;
cudaMalloc(&d_tma_k, total_heads * sizeof(CUtensorMap));
for (int h = 0; h < total_heads; h++) {
create_tma_desc_2d_bf16(&tma_k_arr[h], d_k + h*SK*HD, SK, HD, 128, 16);
}
cudaMemcpy(d_tma_k, tma_k_arr, total_heads * sizeof(CUtensorMap), cudaMemcpyHostToDevice);
FmhaTmaMultiRowParams params;
params.q = d_q; params.tma_k = d_tma_k; params.v = d_v;
params.o = d_o; params.lse = d_lse;
params.s_k = SK; params.T = T; params.scale = SCALE;
params.s_k = SK; params.T = T; params.scale = SCALE; params.n_h = n_h;
params.head_dim = HD;
params.q_head_stride = MAX_T*HD; params.q_batch_stride = n_h*MAX_T*HD;
params.k_head_stride = SK*HD; params.k_batch_stride = n_h*SK*HD;
@@ -157,7 +160,7 @@ static int test_single(int T, int n_h = 1, int batch = 1) {
printf(" min_cos=%.8f %s\n", min_cos, min_cos>0.999f?"PASS":"FAIL");
cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); cudaFree(d_tma_k);
free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse);
free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); free(tma_k_arr);
return min_cos > 0.999f ? 0 : 1;
}