D1.5: Fix flat_divide slice coordinates (4 modes, no STAGE dim)
This commit is contained in:
@@ -451,13 +451,13 @@ class FmhaKernel:
|
||||
tOtO_epi = cute.flat_divide(tOtO_xfm, epi_tile)
|
||||
tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile)
|
||||
# make_tmem_copy with the first sub-tile shape
|
||||
- tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_epi[(None, None, 0, 0, 0)])
|
||||
+ tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_epi[(None, None, 0, 0)])
|
||||
thr_t2r = tiled_copy_t2r.get_slice(sfw_idx)
|
||||
# Partition source (TMEM) and destination (GMEM-derived register shape)
|
||||
tTR_tAcc = thr_t2r.partition_S(tOtO_epi)
|
||||
tTR_gC = thr_t2r.partition_D(tCgC_epi)
|
||||
tTR_rAcc = cute.make_rmem_tensor(
|
||||
- tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype)
|
||||
+ tTR_gC[(None, None, None, 0, 0)].shape, self.acc_dtype)
|
||||
|
||||
# Step 3: SMEM store copy (epilogue_smem_copy_and_partition pattern)
|
||||
smem_copy_atom = get_smem_store_op(
|
||||
|
||||
Reference in New Issue
Block a user