Fix sP_2d definition for tSMEM_CPYsP

This commit is contained in:
2026-05-23 09:34:50 +00:00
parent ffafd47d07
commit 7a74fac11f

View File

@@ -268,6 +268,7 @@ class FmhaKernel:
)
tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx)
sP_2d = cute.group_modes(sP, 0, 3)
tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d) # destination (SMEM)
row_max = -Float32.inf