diff --git a/dsv4/kernels/attention/fmha_6warp_tma_multitile.cuh b/dsv4/kernels/attention/fmha_6warp_tma_multitile.cuh index c9b87dcc..21938e0c 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma_multitile.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma_multitile.cuh @@ -41,7 +41,8 @@ namespace dsv4::kernels::attention { struct FmhaTmaMultiTileParams { const bf16_t* __restrict__ q; CUtensorMap* __restrict__ tma_k; // Array of [n_h] TMA descriptors for K: (s_k, HD) tile (128,16) - const bf16_t* __restrict__ v; // (HD, s_k) + CUtensorMap* __restrict__ tma_v; // Array of [n_h] TMA descriptors for V: (HD, s_k) tile (16,16) + const bf16_t* __restrict__ v; // (HD, s_k) — fallback direct GMEM bf16_t* __restrict__ o; float* __restrict__ lse; int s_k, n_h; @@ -81,6 +82,7 @@ fmha_6warp_tma_multitile_kernel(FmhaTmaMultiTileParams params) { bf16_t* __restrict__ o_head = params.o + head_idx * params.o_head_stride + batch_idx * params.o_batch_stride; float* __restrict__ lse_head = params.lse ? params.lse + head_idx * params.lse_head_stride + batch_idx * params.lse_batch_stride : nullptr; CUtensorMap* __restrict__ my_tma_k = params.tma_k + batch_idx * params.n_h + head_idx; + CUtensorMap* __restrict__ my_tma_v = params.tma_v + batch_idx * params.n_h + head_idx; // ================================================================ // SMEM allocation @@ -224,17 +226,25 @@ fmha_6warp_tma_multitile_kernel(FmhaTmaMultiTileParams params) { int ck = c/8, lc = c%8; sPk[ck*CORES_MN*64 + 0*64 + 0*8 + lc] = f32_to_bf16(s_p_vals[col_start + c]); } - // Load V - for (int i = lane; i < V_SUB_SZ; i += 32) sV[i] = 0; - for (int dd = lane; dd < 16; dd += 32) { - for (int lr = 0; lr < MMA_K_BF16; lr++) { - int r = abs_col + lr; - if (r < s_k && (d_base+dd) < HD) { - int g_mn = dd/8, g_k = lr/8, llr = dd%8, lc = lr%8; - sV[g_k*2*64 + g_mn*64 + llr*8 + lc] = v_head[(d_base+dd)*s_k + r]; - } - } - } + } + __syncthreads(); + + // Load V via TMA: (16, 16) tile at (abs_col, d_base) + // V is (HD, s_k). TMA coord: (innermost=abs_col, outermost=d_base) + if (is_load_warp && lane == 0) { + tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)my_tma_v, + mbar_addr, abs_col, d_base); + tma_mbarrier_arrive_expect_tx(mbar_addr, V_SUB_SZ * sizeof(bf16_t)); + } + tma_mbarrier_wait(mbar_addr, phase); phase ^= 1; + __syncthreads(); + + // Convert sTmaBuf (row-major 16×16) → sV (canonical 16×16) + for (int i = tid; i < V_SUB_SZ; i += 192) sV[i] = 0; + for (int i = tid; i < 16 * MMA_K_BF16; i += 192) { + int dd = i / MMA_K_BF16, lr = i % MMA_K_BF16; + int g_mn = dd/8, g_k = lr/8, llr = dd%8, lc = lr%8; + sV[g_k*2*64 + g_mn*64 + llr*8 + lc] = sTmaBuf[i]; } __syncthreads(); diff --git a/tests/unit/test_fmha_6warp_tma_multitile.cu b/tests/unit/test_fmha_6warp_tma_multitile.cu index e3494c1f..cc5cfb89 100644 --- a/tests/unit/test_fmha_6warp_tma_multitile.cu +++ b/tests/unit/test_fmha_6warp_tma_multitile.cu @@ -104,8 +104,13 @@ int main() { cudaMalloc(&d_tma_k, sizeof(CUtensorMap)); cudaMemcpy(d_tma_k, &tma_k, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + CUtensorMap tma_v; CUtensorMap* d_tma_v; + create_tma_desc_2d_bf16(&tma_v, d_v, HD, s_k, 16, 16); + cudaMalloc(&d_tma_v, sizeof(CUtensorMap)); + cudaMemcpy(d_tma_v, &tma_v, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + FmhaTmaMultiTileParams params; - params.q = d_q; params.tma_k = d_tma_k; params.v = d_v; + params.q = d_q; params.tma_k = d_tma_k; params.tma_v = d_tma_v; params.v = d_v; params.o = d_o; params.lse = d_lse; params.s_k = s_k; params.n_h = 1; params.scale = SCALE; params.q_head_stride = HD; params.q_batch_stride = HD; @@ -144,7 +149,7 @@ int main() { printf(" ref[0..3]: "); for(int d=0;d<4;d++) printf("%.6f ", o_ref[d]); printf("\n"); } - cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); cudaFree(d_tma_k); + cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); cudaFree(d_tma_k); cudaFree(d_tma_v); free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); free(o_ref); }