diff --git a/dsv4/kernels/attention/fmha_6warp_tma.cuh b/dsv4/kernels/attention/fmha_6warp_tma.cuh index df3f5d69..ad4f4336 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma.cuh @@ -49,8 +49,7 @@ __global__ void __launch_bounds__(192) fmha_6warp_tma_kernel( const bf16_t* __restrict__ q, CUtensorMap* __restrict__ tma_k, - const bf16_t* __restrict__ v, - CUtensorMap* __restrict__ tma_v_unused, + CUtensorMap* __restrict__ tma_v, bf16_t* __restrict__ o, float* __restrict__ lse, int s_k, @@ -207,25 +206,33 @@ fmha_6warp_tma_kernel( int d_base = n * 16; for (int kt = 0; kt < NKT_PV; kt++) { - // ---- Warp 5: Fill sPk and load V sub-tile (direct) ---- + // ---- Fill sPk ---- if (is_load_warp) { - // Fill sPk from s_p_vals for (int i = lane; i < TILE_SZ; i += 32) sPk[i] = 0; if (lane < 16) { int c = lane; int ck = c / 8, lc = c % 8; sPk[ck * CORES_MN * 64 + 0 * 64 + 0 * 8 + lc] = f32_to_bf16(s_p_vals[kt * MMA_K_BF16 + c]); } - // Load V sub-tile: direct from GMEM - for (int i = lane; i < V_SUB_SZ; i += 32) sV[i] = 0; - for (int dd = lane; dd < 16; dd += 32) { - for (int lr = 0; lr < MMA_K_BF16; lr++) { - int r = kt * MMA_K_BF16 + lr; - int g_mn = dd / 8, g_k = lr / 8; - int llr = dd % 8, lc = lr % 8; - sV[g_k * 2 * 64 + g_mn * 64 + llr * 8 + lc] = v[(d_base + dd) * SK_TILE + r]; - } - } + } + __syncthreads(); + + // ---- Load V via TMA ---- + if (is_load_warp && lane == 0) { + tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_v, + mbar_addr, kt * MMA_K_BF16, d_base); + tma_mbarrier_arrive_expect_tx(mbar_addr, V_SUB_SZ * sizeof(bf16_t)); + } + tma_mbarrier_wait(mbar_addr, phase); phase ^= 1; + __syncthreads(); + + // Convert sTmaBuf → canonical sV + for (int i = tid; i < V_SUB_SZ; i += 192) sV[i] = 0; + for (int i = tid; i < 16 * MMA_K_BF16; i += 192) { + int dd = i / MMA_K_BF16, lr = i % MMA_K_BF16; + int g_mn = dd/8, g_k = lr/8, llr = dd%8, lc = lr%8; + sV[g_k*2*64 + g_mn*64 + llr*8 + lc] = sTmaBuf[i]; + } } __syncthreads(); diff --git a/tests/unit/test_fmha_6warp_tma.cu b/tests/unit/test_fmha_6warp_tma.cu index 6e0979a4..d5a4f8e3 100644 --- a/tests/unit/test_fmha_6warp_tma.cu +++ b/tests/unit/test_fmha_6warp_tma.cu @@ -122,7 +122,7 @@ int main() { if (smem > 48 * 1024) { cudaFuncSetAttribute(fmha_6warp_tma_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem); } - fmha_6warp_tma_kernel<<<1, 192, smem>>>(d_q, d_tma_k, d_v, d_tma_v, d_o, d_lse, SK, SCALE); + fmha_6warp_tma_kernel<<<1, 192, smem>>>(d_q, d_tma_k, d_tma_v, d_o, d_lse, SK, SCALE); cudaError_t launch_err = cudaGetLastError(); if (launch_err != cudaSuccess) { printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err)); return 1; }