D1.5: Fix TMA store - use group_modes on sC and tCgC
This commit is contained in:
@@ -492,11 +492,13 @@ class FmhaKernel:
|
||||
# The normalized O is now in sC (written by the correction epilog).
|
||||
# Use the same TMA store pattern as the CUTLASS FMHA reference.
|
||||
# Partition sC and gC for the bulk TMA copy.
|
||||
sC_epi = cute.select(sC, mode=[0, 1]) # 2D view for TMA
|
||||
gC_epi = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))) # 2D output tile
|
||||
# Get 2D views of sC and gC for TMA.
|
||||
sC_epi = sC # sC already has c_smem_s layout, TMA can handle it
|
||||
gC_epi = tCgC # Use the pre-partitioned GMEM tensor
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
sC_epi, gC_epi,
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(gC_epi, 0, 2),
|
||||
)
|
||||
# One TMA store for the full output tile
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
|
||||
Reference in New Issue
Block a user