Fix CuTeDSL scoping: hoist TMEM-P store variables out of the `if not use_smem_p` block

This commit is contained in:
2026-05-23 05:12:30 +00:00
parent d15bb7b84a
commit 7f1febccf0

View File

@@ -341,26 +341,22 @@ class FmhaKernel:
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# ── TMEM-P: P store setup (register bridge) ──
if not use_smem_p:
p_cols_fp32 = self.pv_mma_tiler[1] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype,
)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# ── P store setup (TMEM-P vars defined unconditionally — CuTeDSL scoping; SMEM-P still TODO) ──
# TMEM-P: register bridge for P → TMEM
p_cols_fp32 = self.pv_mma_tiler[1] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype,
)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# ── SMEM-P: P → SMEM copy setup (TODO: proper QK→PV partition remap) ──
if use_smem_p:
# TODO: make_tiled_copy_C(store_atom, qk_mma) to partition threads by QK's C-fragment
# For now, zero sP as a stub — PV will read garbage/zero
pass
# SMEM-P: TODO — make_tiled_copy_C(store_atom, qk_mma) for QK→PV partition remap
# ── O rescale / normalization setup (correction_rescale pattern from Stage C) ──
corr_tile_size = 16