D1.5: Fix flat_divide slice coordinates (4 modes, no STAGE dim)

This commit is contained in:
2026-05-24 01:39:21 +00:00
parent d6a607d12e
commit cb6eae4c4f

View File

@@ -451,13 +451,13 @@ class FmhaKernel:
tOtO_epi = cute.flat_divide(tOtO_xfm, epi_tile)
tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile)
# make_tmem_copy with the first sub-tile shape
-tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_epi[(None, None, 0, 0, 0)])
+tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_epi[(None, None, 0, 0)])
thr_t2r = tiled_copy_t2r.get_slice(sfw_idx)
# Partition source (TMEM) and destination (GMEM-derived register shape)
tTR_tAcc = thr_t2r.partition_S(tOtO_epi)
tTR_gC = thr_t2r.partition_D(tCgC_epi)
tTR_rAcc = cute.make_rmem_tensor(
-tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype)
+tTR_gC[(None, None, None, 0, 0)].shape, self.acc_dtype)
# Step 3: SMEM store copy (epilogue_smem_copy_and_partition pattern)
smem_copy_atom = get_smem_store_op(