feat: per-head TMA descriptors for multi-head FMHA

This commit is contained in:
2026-05-29 19:44:58 +00:00
parent 9eb193458e
commit 754c6a692c
2 changed files with 14 additions and 11 deletions

View File

@@ -100,16 +100,19 @@ static int test_single(int T, int n_h = 1, int batch = 1) {
cudaMemset(d_o, 0, total_heads*MAX_T*HD*sizeof(bf16_t));
cudaMemset(d_lse, 0, total_heads*MAX_T*sizeof(float));
// TMA descriptor for K: (SK, HD) — same for all heads (for now)
CUtensorMap tma_k; CUtensorMap* d_tma_k;
create_tma_desc_2d_bf16(&tma_k, d_k, SK, HD, 128, 16);
cudaMalloc(&d_tma_k, sizeof(CUtensorMap));
cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
// TMA descriptors for K: one per head
CUtensorMap* tma_k_arr = (CUtensorMap*)malloc(total_heads * sizeof(CUtensorMap));
CUtensorMap* d_tma_k;
cudaMalloc(&d_tma_k, total_heads * sizeof(CUtensorMap));
for (int h = 0; h < total_heads; h++) {
create_tma_desc_2d_bf16(&tma_k_arr[h], d_k + h*SK*HD, SK, HD, 128, 16);
}
cudaMemcpy(d_tma_k, tma_k_arr, total_heads * sizeof(CUtensorMap), cudaMemcpyHostToDevice);
FmhaTmaMultiRowParams params;
params.q = d_q; params.tma_k = d_tma_k; params.v = d_v;
params.o = d_o; params.lse = d_lse;
params.s_k = SK; params.T = T; params.scale = SCALE;
params.s_k = SK; params.T = T; params.scale = SCALE; params.n_h = n_h;
params.head_dim = HD;
params.q_head_stride = MAX_T*HD; params.q_batch_stride = n_h*MAX_T*HD;
params.k_head_stride = SK*HD; params.k_batch_stride = n_h*SK*HD;
@@ -157,7 +160,7 @@ static int test_single(int T, int n_h = 1, int batch = 1) {
printf(" min_cos=%.8f %s\n", min_cos, min_cos>0.999f?"PASS":"FAIL");
cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); cudaFree(d_tma_k);
free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse);
free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); free(tma_k_arr);
return min_cos > 0.999f ? 0 : 1;
}