SMEM-P: try using PV A-operand layout directly for TMEM-P
This commit is contained in:
@@ -253,15 +253,17 @@ class FmhaKernel:
|
||||
|
||||
# P store atoms: TMEM-P (always defined, only used when use_smem_p=False)
|
||||
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
# When SMEM-P is in use, clamp the P offset to 0 (these atoms are never used, but must still be valid)
|
||||
tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout)
|
||||
# Try using PV A-operand layout directly (p_tmem_s.outer) instead of composed layout
|
||||
# tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
# Use PV A-operand TMEM layout
|
||||
tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), p_tmem_s.outer)
|
||||
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtP = thr_store.partition_D(tStP0)
|
||||
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
# tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
# Use same PV A-operand layout for coordinate tensor
|
||||
tScP = cute.make_tensor(tScS.iterator, p_tmem_s.outer)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# Manual SMEM addressing for P (CUTLASS LLM guidance)
|
||||
|
||||
Reference in New Issue
Block a user