From ec250eccd6a67683dcd3b19032e6d25d2c4e2f79 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 01:36:01 +0000 Subject: [PATCH] D1.5: Fix TMA store - group_modes on bSG_gC, use flat indexing --- dsv4/kernels/attention/fmha.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 4ce2b669..a2188329 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -500,9 +500,11 @@ class FmhaKernel: cute.group_modes(sC, 0, 2), cute.group_modes(tCgC_epi, 0, 2), ) + # Group all modes >= 1 into one (CUTLASS pattern) + bSG_gC_flat = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) # One TMA store for the full output tile if warp_idx == self.epilogue_warp_id[0]: - cute.copy(tma_c, bSG_sC[(None, 0)], bSG_gC[(None, None, None, Int32(0), Int32(0), Int32(0))]) + cute.copy(tma_c, bSG_sC[(None, 0)], bSG_gC_flat[(None, Int32(0))]) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True)