feat: per-head TMA descriptors for multi-head FMHA
This commit is contained in:
@@ -100,16 +100,19 @@ static int test_single(int T, int n_h = 1, int batch = 1) {
|
||||
cudaMemset(d_o, 0, total_heads*MAX_T*HD*sizeof(bf16_t));
|
||||
cudaMemset(d_lse, 0, total_heads*MAX_T*sizeof(float));
|
||||
|
||||
// TMA descriptor for K: (SK, HD) — same for all heads (for now)
|
||||
CUtensorMap tma_k; CUtensorMap* d_tma_k;
|
||||
create_tma_desc_2d_bf16(&tma_k, d_k, SK, HD, 128, 16);
|
||||
cudaMalloc(&d_tma_k, sizeof(CUtensorMap));
|
||||
cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
// TMA descriptors for K: one per head
|
||||
CUtensorMap* tma_k_arr = (CUtensorMap*)malloc(total_heads * sizeof(CUtensorMap));
|
||||
CUtensorMap* d_tma_k;
|
||||
cudaMalloc(&d_tma_k, total_heads * sizeof(CUtensorMap));
|
||||
for (int h = 0; h < total_heads; h++) {
|
||||
create_tma_desc_2d_bf16(&tma_k_arr[h], d_k + h*SK*HD, SK, HD, 128, 16);
|
||||
}
|
||||
cudaMemcpy(d_tma_k, tma_k_arr, total_heads * sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
|
||||
FmhaTmaMultiRowParams params;
|
||||
params.q = d_q; params.tma_k = d_tma_k; params.v = d_v;
|
||||
params.o = d_o; params.lse = d_lse;
|
||||
params.s_k = SK; params.T = T; params.scale = SCALE;
|
||||
params.s_k = SK; params.T = T; params.scale = SCALE; params.n_h = n_h;
|
||||
params.head_dim = HD;
|
||||
params.q_head_stride = MAX_T*HD; params.q_batch_stride = n_h*MAX_T*HD;
|
||||
params.k_head_stride = SK*HD; params.k_batch_stride = n_h*SK*HD;
|
||||
@@ -157,7 +160,7 @@ static int test_single(int T, int n_h = 1, int batch = 1) {
|
||||
printf(" min_cos=%.8f %s\n", min_cos, min_cos>0.999f?"PASS":"FAIL");
|
||||
|
||||
cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); cudaFree(d_tma_k);
|
||||
free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse);
|
||||
free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); free(tma_k_arr);
|
||||
|
||||
return min_cos > 0.999f ? 0 : 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user