From 699c646497a2ba9f2b1ba00a148bcf8b26dcb1d6 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 01:41:52 +0000 Subject: [PATCH] D1.5: Fix bSG_gC slicing - group trailing modes (CUTLASS pattern) --- dsv4/kernels/attention/fmha.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index be2b6085..aeb5d93f 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -512,8 +512,9 @@ class FmhaKernel: # TMA store SMEM → GMEM if warp_idx == self.epilogue_warp_id[0]: - cute.copy(tma_c, bSG_sC[(None, c_buffer)], - bSG_gC[(None, None, None, Int32(0), Int32(0), Int32(0))]) + # Group trailing modes and slice (CUTLASS pattern) + bSG_gC_flat = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + cute.copy(tma_c, bSG_sC[(None, c_buffer)], bSG_gC_flat[(None, Int32(0))]) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) corr_epi_bar.arrive_and_wait()