feat: per-head TMA descriptors for multi-head FMHA
This commit is contained in:
@@ -34,11 +34,11 @@ namespace dsv4::kernels::attention {
|
||||
|
||||
struct FmhaTmaMultiRowParams {
|
||||
const bf16_t* __restrict__ q;
|
||||
CUtensorMap* __restrict__ tma_k; // K: (s_k, HD) with tile (128, 16)
|
||||
CUtensorMap* __restrict__ tma_k; // Array of [n_h] TMA descriptors for K
|
||||
const bf16_t* __restrict__ v; // V: direct GMEM (HD, s_k)
|
||||
bf16_t* __restrict__ o;
|
||||
float* __restrict__ lse;
|
||||
int s_k, T;
|
||||
int s_k, T, n_h;
|
||||
float scale;
|
||||
int head_dim;
|
||||
int q_head_stride, q_batch_stride;
|
||||
@@ -111,7 +111,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
|
||||
int phase = 0;
|
||||
|
||||
// Row assignment
|
||||
CUtensorMap* __restrict__ my_tma_k = params.tma_k + head_idx;
|
||||
const bool my_warp_active = (T <= 32) ? (wid == 0) : is_softmax_warp;
|
||||
const int my_row = my_warp_active ? (wid * 32 + lane) : 0;
|
||||
const bool my_row_active = my_warp_active && (my_row < T);
|
||||
@@ -136,7 +136,7 @@ fmha_6warp_tma_multirow_kernel(FmhaTmaMultiRowParams params) {
|
||||
|
||||
// Load K: TMA async
|
||||
if (is_load_warp && lane == 0) {
|
||||
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)params.tma_k,
|
||||
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_k,
|
||||
mbar_addr, kt * MMA_K_BF16, 0);
|
||||
tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
|
||||
}
|
||||
|
||||
@@ -100,16 +100,19 @@ static int test_single(int T, int n_h = 1, int batch = 1) {
|
||||
cudaMemset(d_o, 0, total_heads*MAX_T*HD*sizeof(bf16_t));
|
||||
cudaMemset(d_lse, 0, total_heads*MAX_T*sizeof(float));
|
||||
|
||||
// TMA descriptor for K: (SK, HD) — same for all heads (for now)
|
||||
CUtensorMap tma_k; CUtensorMap* d_tma_k;
|
||||
create_tma_desc_2d_bf16(&tma_k, d_k, SK, HD, 128, 16);
|
||||
cudaMalloc(&d_tma_k, sizeof(CUtensorMap));
|
||||
cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
// TMA descriptors for K: one per head
|
||||
CUtensorMap* tma_k_arr = (CUtensorMap*)malloc(total_heads * sizeof(CUtensorMap));
|
||||
CUtensorMap* d_tma_k;
|
||||
cudaMalloc(&d_tma_k, total_heads * sizeof(CUtensorMap));
|
||||
for (int h = 0; h < total_heads; h++) {
|
||||
create_tma_desc_2d_bf16(&tma_k_arr[h], d_k + h*SK*HD, SK, HD, 128, 16);
|
||||
}
|
||||
cudaMemcpy(d_tma_k, tma_k_arr, total_heads * sizeof(CUtensorMap), cudaMemcpyHostToDevice);
|
||||
|
||||
FmhaTmaMultiRowParams params;
|
||||
params.q = d_q; params.tma_k = d_tma_k; params.v = d_v;
|
||||
params.o = d_o; params.lse = d_lse;
|
||||
params.s_k = SK; params.T = T; params.scale = SCALE;
|
||||
params.s_k = SK; params.T = T; params.scale = SCALE; params.n_h = n_h;
|
||||
params.head_dim = HD;
|
||||
params.q_head_stride = MAX_T*HD; params.q_batch_stride = n_h*MAX_T*HD;
|
||||
params.k_head_stride = SK*HD; params.k_batch_stride = n_h*SK*HD;
|
||||
@@ -157,7 +160,7 @@ static int test_single(int T, int n_h = 1, int batch = 1) {
|
||||
printf(" min_cos=%.8f %s\n", min_cos, min_cos>0.999f?"PASS":"FAIL");
|
||||
|
||||
cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); cudaFree(d_tma_k);
|
||||
free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse);
|
||||
free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); free(tma_k_arr);
|
||||
|
||||
return min_cos > 0.999f ? 0 : 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user