fix: PV accumulate flag
This commit is contained in:
@@ -278,8 +278,7 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
|
||||
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
|
||||
// TMEM column offset: (n_sub - n_sub_start) * 16
|
||||
int tmem_col = (n_sub - n_sub_start) * 16;
|
||||
bool accumulate = (pv_kt > 0) || (n_sub > n_sub_start);
|
||||
if (tid == 128) umma_ss_f16(tb + tmem_col, dp, dv, idesc_pv, accumulate);
|
||||
if (tid == 128) umma_ss_f16(tb + tmem_col, dp, dv, idesc_pv, pv_kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
Reference in New Issue
Block a user