Fix CuTeDSL scoping: hoist P store vars out of if block
This commit is contained in:
@@ -341,26 +341,22 @@ class FmhaKernel:
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# ── TMEM-P: P store setup (register bridge) ──
|
||||
if not use_smem_p:
|
||||
p_cols_fp32 = self.pv_mma_tiler[1] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype,
|
||||
)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtP = thr_store.partition_D(tStP0)
|
||||
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
# ── P store setup (always define both paths — CuTeDSL scoping) ──
|
||||
# TMEM-P: register bridge for P → TMEM
|
||||
p_cols_fp32 = self.pv_mma_tiler[1] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype,
|
||||
)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtP = thr_store.partition_D(tStP0)
|
||||
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# ── SMEM-P: P → SMEM copy setup (TODO: proper QK→PV partition remap) ──
|
||||
if use_smem_p:
|
||||
# TODO: make_tiled_copy_C(store_atom, qk_mma) to partition threads by QK's C-fragment
|
||||
# For now, zero sP as a stub — PV will read garbage/zero
|
||||
pass
|
||||
# SMEM-P: TODO — make_tiled_copy_C(store_atom, qk_mma) for QK→PV partition remap
|
||||
|
||||
# ── O rescale / normalization setup (correction_rescale pattern from Stage C) ──
|
||||
corr_tile_size = 16
|
||||
|
||||
Reference in New Issue
Block a user