diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 0cf1ca0f..79e3d645 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -250,7 +250,7 @@ class FmhaKernel: for j in cutlass.range(cute.size(rP_bf16_reg), vectorize=True): # TODO: proper element mapping from QK→PV partition rP_bf16_reg[j] = BFloat16(0.0) - cute.copy(rP_bf16_reg, tCrP_smem) + cute.copy(tCrP_smem, rP_bf16_reg) cute.arch.fence_proxy("async.shared", space="cta") si_handle.release()