fix: use cpasync.CopyOp for the reg→SMEM copy of P
This commit is contained in:
@@ -200,8 +200,8 @@ class FmhaKernel:
 tScS = qk_thr.partition_C(cS)
 tTMEM_LOADcS = thr_load.partition_D(tScS)
 
-# P → SMEM copy setup (use SMEM copy atom for register→SMEM)
-p_copy_atom = cute.make_copy_atom(cute.CopyAtomUniversalOp(), self.q_dtype)
+# P → SMEM copy setup
+p_copy_atom = cute.make_copy_atom(cpasync.CopyOp(), self.q_dtype)
 tiled_p_copy = cute.make_tiled_copy(p_copy_atom, tCrP_smem.layout, tCrP_reg.layout, tidx)
 tPS_sP = tiled_p_copy.get_slice(sfw_idx).partition_D(tCrP_smem)
 tPS_rP = tiled_p_copy.get_slice(sfw_idx).partition_S(tCrP_reg)
Reference in New Issue
Block a user