fix: use partition_A, not partition_S, for the P → SMEM partition
This commit is contained in:
@@ -200,9 +200,9 @@ class FmhaKernel:
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# P → SMEM copy (using PV A-operand thread partition)
|
||||
# P → SMEM: use PV A-operand partition for SMEM write
|
||||
p_s = cute.slice_(p_smem_s,(None,None,None,0))
|
||||
tCrP_smem = pv_thr.partition_S(sP) # softmax thread → SMEM partition for P
|
||||
tCrP_smem = pv_thr.partition_A(sP) # PV thread → SMEM partition for P (A operand)
|
||||
tCrP_reg = cute.make_rmem_tensor(tCrP_smem.shape, self.q_dtype)
|
||||
|
||||
# Online softmax state
|
||||
|
||||
Reference in New Issue
Block a user