fix: use cpasync.CopyOp for the P reg→SMEM copy atom

This commit is contained in:
2026-05-23 03:54:49 +00:00
parent a18c639021
commit cf080ccf00

View File

@@ -200,8 +200,8 @@ class FmhaKernel:
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# P → SMEM copy setup (use SMEM copy atom for register→SMEM)
p_copy_atom = cute.make_copy_atom(cute.CopyAtomUniversalOp(), self.q_dtype)
# P → SMEM copy setup
p_copy_atom = cute.make_copy_atom(cpasync.CopyOp(), self.q_dtype)
tiled_p_copy = cute.make_tiled_copy(p_copy_atom, tCrP_smem.layout, tCrP_reg.layout, tidx)
tPS_sP = tiled_p_copy.get_slice(sfw_idx).partition_D(tCrP_smem)
tPS_rP = tiled_p_copy.get_slice(sfw_idx).partition_S(tCrP_reg)