From cb6eae4c4f9ba3fcc3f121b8577e67df0310df24 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 01:39:21 +0000 Subject: [PATCH] D1.5: Fix flat_divide slice coordinates (4 modes, no STAGE dim) --- dsv4/kernels/attention/fmha.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 82bcc044..6c6127ac 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -451,13 +451,13 @@ class FmhaKernel: tOtO_epi = cute.flat_divide(tOtO_xfm, epi_tile) tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile) # make_tmem_copy with the first sub-tile shape - tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_epi[(None, None, 0, 0, 0)]) + tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_epi[(None, None, 0, 0)]) thr_t2r = tiled_copy_t2r.get_slice(sfw_idx) # Partition source (TMEM) and destination (GMEM-derived register shape) tTR_tAcc = thr_t2r.partition_S(tOtO_epi) tTR_gC = thr_t2r.partition_D(tCgC_epi) tTR_rAcc = cute.make_rmem_tensor( - tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) + tTR_gC[(None, None, None, 0, 0)].shape, self.acc_dtype) # Step 3: SMEM store copy (epilogue_smem_copy_and_partition pattern) smem_copy_atom = get_smem_store_op(