fix: partition_A not partition_S

This commit is contained in:
2026-05-23 03:47:53 +00:00
parent 3ee330a84c
commit 469665f69a

View File

@@ -200,9 +200,9 @@ class FmhaKernel:
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# P → SMEM copy (using PV A-operand thread partition)
# P → SMEM: use PV A-operand partition for SMEM write
p_s = cute.slice_(p_smem_s,(None,None,None,0))
tCrP_smem = pv_thr.partition_S(sP) # softmax thread → SMEM partition for P
tCrP_smem = pv_thr.partition_A(sP) # PV thread → SMEM partition for P (A operand)
tCrP_reg = cute.make_rmem_tensor(tCrP_smem.shape, self.q_dtype)
# Online softmax state