diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 3bad0704..9624dcab 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -200,11 +200,12 @@ class FmhaKernel: tScS = qk_thr.partition_C(cS) tTMEM_LOADcS = thr_load.partition_D(tScS) - # P → SMEM copy setup - p_copy_atom = cute.make_copy_atom(cpasync.CopyOp(), self.q_dtype) - tiled_p_copy = cute.make_tiled_copy(p_copy_atom, tCrP_smem.layout, tCrP_reg.layout, tidx) - tPS_sP = tiled_p_copy.get_slice(sfw_idx).partition_D(tCrP_smem) - tPS_rP = tiled_p_copy.get_slice(sfw_idx).partition_S(tCrP_reg) + # P → SMEM: use make_tiled_copy_C for register→SMEM (standard epilogue pattern) + # The P values are the A operand of PV, written to SMEM so the MMA can read them + p_s = cute.slice_(p_smem_s,(None,None,None,0)) + tCrP_smem = pv_thr.partition_A(sP) # PV thread → SMEM partition for P (A operand) + tCrP_reg = pv_mma.make_fragment_A(sP) # register fragment matching SMEM layout + tiled_p_copy = cute.make_tiled_copy_C(pv_mma, tCrP_smem, p_s, 1) # Online softmax state row_max = -Float32.inf @@ -240,11 +241,11 @@ class FmhaKernel: tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) row_sum = row_sum + tTMEM_LOADrS_frg[k, j] - # Write P to SMEM using PV A-operand thread partition + # Write P to SMEM using PV A-operand partition # TODO: proper element mapping from QK→PV partition for j in cutlass.range(cute.size(tCrP_reg), vectorize=True): tCrP_reg[j] = BFloat16(0.0) - cute.copy(tiled_p_copy, tPS_rP, tPS_sP) + cute.copy(tiled_p_copy, tCrP_reg, tCrP_smem) cute.arch.fence_proxy("async.shared", space="cta") si_handle.release()