Fix p_tmem_s: use ComposedLayout from make_smem_layout_a, pass as kernel arg

This commit is contained in:
2026-05-23 05:06:45 +00:00
parent d0869054d5
commit 2cc9786491

View File

@@ -90,8 +90,8 @@ class FmhaKernel:
if self.num_tmem_alloc_cols > 512:
print(f"⚠️ TMEM BUDGET: {self.num_tmem_alloc_cols} cols (hd={hd})")
# P TMEM alias (PV A-operand viewed as TMEM for partition mapping)
self.p_tmem_s = tStS # reuses QK C-fragment TMEM layout for P partition
# P TMEM layout (PV A-operand SMEM layout — used to alias QK C-fragment in TMEM)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
# TMA bytes
cta = cute.size(qk_mma.thr_id.shape)
@@ -143,7 +143,7 @@ class FmhaKernel:
self._kernel(
qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC,
self.cluster_layout_vmnk,
self.q_smem_s, self.k_smem_s, self.v_smem_s, self.p_smem_s, self.c_smem_s,
self.q_smem_s, self.k_smem_s, self.v_smem_s, self.p_smem_s, self.p_tmem_s, self.c_smem_s,
self.epi_tile,
).launch(
grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream,
@@ -152,7 +152,7 @@ class FmhaKernel:
@cute.kernel
def _kernel(
self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC,
cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_smem_s, c_smem_s, epi_tile,
cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_smem_s, p_tmem_s, c_smem_s, epi_tile,
):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
@@ -253,7 +253,7 @@ class FmhaKernel:
# ── TMEM-P path: PV A-operand from TMEM ──
if not use_smem_p:
tP = cute.make_tensor(tStS.iterator, self.p_tmem_s.outer)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(