D1: P store uses tOrP0.layout (PV A-fragment TMEM layout)
This commit is contained in:
@@ -225,15 +225,17 @@ class FmhaKernel:
|
||||
|
||||
# P store atoms
|
||||
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
# P store must use the PV A-fragment layout (p_tmem_s.outer), not the QK C-fragment layout.
|
||||
# At hd=64 these match by coincidence; at hd>64 they diverge, causing garbage PV output.
|
||||
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, p_tmem_s.outer)
|
||||
# P store must use the PV A-fragment layout so PV reads correct TMEM columns.
|
||||
# Use tOrP0's layout (PV A-fragment) for the store target, not QK C-fragment composition.
|
||||
tStP0 = cute.make_tensor(tOrP0.iterator, tOrP0.layout)
|
||||
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtP = thr_store.partition_D(tStP0)
|
||||
tScP = cute.make_tensor(tScS.iterator, p_tmem_s.outer)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
# tScP: coordinate tensor for P store — use PV A-fragment layout
|
||||
cP = cute.make_identity_tensor((self.pv_mma_tiler[0], p_cols_fp32))
|
||||
tOcP = pv_thr.partition_A(cP) # partition using PV thread slice
|
||||
tTMEM_STOREcP = thr_store.partition_S(cute.make_tensor(tOcP.iterator, tOrP0.layout))
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
|
||||
Reference in New Issue
Block a user