feat: V TMA loads in multi-tile kernel

This commit is contained in:
2026-05-29 22:46:21 +00:00
parent 680d2ebf64
commit 74145a31cc
2 changed files with 29 additions and 14 deletions

View File

@@ -41,7 +41,8 @@ namespace dsv4::kernels::attention {
struct FmhaTmaMultiTileParams {
const bf16_t* __restrict__ q;
CUtensorMap* __restrict__ tma_k; // Array of [n_h] TMA descriptors for K: (s_k, HD) tile (128,16)
const bf16_t* __restrict__ v; // (HD, s_k)
CUtensorMap* __restrict__ tma_v; // Array of [n_h] TMA descriptors for V: (HD, s_k) tile (16,16)
const bf16_t* __restrict__ v; // (HD, s_k) — fallback direct GMEM
bf16_t* __restrict__ o;
float* __restrict__ lse;
int s_k, n_h;
@@ -81,6 +82,7 @@ fmha_6warp_tma_multitile_kernel(FmhaTmaMultiTileParams params) {
bf16_t* __restrict__ o_head = params.o + head_idx * params.o_head_stride + batch_idx * params.o_batch_stride;
float* __restrict__ lse_head = params.lse ? params.lse + head_idx * params.lse_head_stride + batch_idx * params.lse_batch_stride : nullptr;
CUtensorMap* __restrict__ my_tma_k = params.tma_k + batch_idx * params.n_h + head_idx;
CUtensorMap* __restrict__ my_tma_v = params.tma_v + batch_idx * params.n_h + head_idx;
// ================================================================
// SMEM allocation
@@ -224,17 +226,25 @@ fmha_6warp_tma_multitile_kernel(FmhaTmaMultiTileParams params) {
int ck = c/8, lc = c%8;
sPk[ck*CORES_MN*64 + 0*64 + 0*8 + lc] = f32_to_bf16(s_p_vals[col_start + c]);
}
// Load V
for (int i = lane; i < V_SUB_SZ; i += 32) sV[i] = 0;
for (int dd = lane; dd < 16; dd += 32) {
for (int lr = 0; lr < MMA_K_BF16; lr++) {
int r = abs_col + lr;
if (r < s_k && (d_base+dd) < HD) {
int g_mn = dd/8, g_k = lr/8, llr = dd%8, lc = lr%8;
sV[g_k*2*64 + g_mn*64 + llr*8 + lc] = v_head[(d_base+dd)*s_k + r];
}
}
}
}
__syncthreads();
// Load V via TMA: (16, 16) tile at (abs_col, d_base)
// V is (HD, s_k). TMA coord: (innermost=abs_col, outermost=d_base)
if (is_load_warp && lane == 0) {
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_v,
mbar_addr, abs_col, d_base);
tma_mbarrier_arrive_expect_tx(mbar_addr, V_SUB_SZ * sizeof(bf16_t));
}
tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
__syncthreads();
// Convert sTmaBuf (row-major 16×16) → sV (canonical 16×16)
for (int i = tid; i < V_SUB_SZ; i += 192) sV[i] = 0;
for (int i = tid; i < 16 * MMA_K_BF16; i += 192) {
int dd = i / MMA_K_BF16, lr = i % MMA_K_BF16;
int g_mn = dd/8, g_k = lr/8, llr = dd%8, lc = lr%8;
sV[g_k*2*64 + g_mn*64 + llr*8 + lc] = sTmaBuf[i];
}
__syncthreads();

View File

@@ -104,8 +104,13 @@ int main() {
cudaMalloc(&d_tma_k, sizeof(CUtensorMap));
cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
CUtensorMap tma_v; CUtensorMap* d_tma_v;
create_tma_desc_2d_bf16(&tma_v, d_v, HD, s_k, 16, 16);
cudaMalloc(&d_tma_v, sizeof(CUtensorMap));
cudaMemcpy(d_tma_v, &tma_v, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
FmhaTmaMultiTileParams params;
params.q = d_q; params.tma_k = d_tma_k; params.v = d_v;
params.q = d_q; params.tma_k = d_tma_k; params.tma_v = d_tma_v; params.v = d_v;
params.o = d_o; params.lse = d_lse;
params.s_k = s_k; params.n_h = 1; params.scale = SCALE;
params.q_head_stride = HD; params.q_batch_stride = HD;
@@ -144,7 +149,7 @@ int main() {
printf(" ref[0..3]: "); for(int d=0;d<4;d++) printf("%.6f ", o_ref[d]); printf("\n");
}
cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); cudaFree(d_tma_k);
cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); cudaFree(d_tma_k); cudaFree(d_tma_v);
free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); free(o_ref);
}