diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 787cf426..ea02f698 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -200,9 +200,9 @@ class FmhaKernel: tScS = qk_thr.partition_C(cS) tTMEM_LOADcS = thr_load.partition_D(tScS) - # P → SMEM copy (using PV A-operand thread partition) + # P → SMEM: use PV A-operand partition for SMEM write p_s = cute.slice_(p_smem_s,(None,None,None,0)) - tCrP_smem = pv_thr.partition_S(sP) # softmax thread → SMEM partition for P + tCrP_smem = pv_thr.partition_A(sP) # PV thread → SMEM partition for P (A operand) tCrP_reg = cute.make_rmem_tensor(tCrP_smem.shape, self.q_dtype) # Online softmax state