fix: PV accumulate flag

This commit is contained in:
2026-05-30 06:56:53 +00:00
parent 1da785c070
commit 25aeaca9ab

View File

@@ -278,8 +278,7 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
// TMEM column offset: (n_sub - n_sub_start) * 16
int tmem_col = (n_sub - n_sub_start) * 16;
bool accumulate = (pv_kt > 0) || (n_sub > n_sub_start);
if (tid == 128) umma_ss_f16(tb + tmem_col, dp, dv, idesc_pv, accumulate);
if (tid == 128) umma_ss_f16(tb + tmem_col, dp, dv, idesc_pv, pv_kt > 0);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}
__syncthreads();