From cf080ccf0074b7f1813d208515b1b116f883d3ee Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 03:54:49 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20cpasync.CopyOp=20for=20reg=E2=86=92SMEM?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dsv4/kernels/attention/fmha.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index f3d5572b..3bad0704 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -200,8 +200,8 @@ class FmhaKernel: tScS = qk_thr.partition_C(cS) tTMEM_LOADcS = thr_load.partition_D(tScS) - # P → SMEM copy setup (use SMEM copy atom for register→SMEM) - p_copy_atom = cute.make_copy_atom(cute.CopyAtomUniversalOp(), self.q_dtype) + # P → SMEM copy setup + p_copy_atom = cute.make_copy_atom(cpasync.CopyOp(), self.q_dtype) tiled_p_copy = cute.make_tiled_copy(p_copy_atom, tCrP_smem.layout, tCrP_reg.layout, tidx) tPS_sP = tiled_p_copy.get_slice(sfw_idx).partition_D(tCrP_smem) tPS_rP = tiled_p_copy.get_slice(sfw_idx).partition_S(tCrP_reg)