Fix p_tmem_s: use ComposedLayout from make_smem_layout_a, pass as kernel arg
This commit is contained in:
@@ -90,8 +90,8 @@ class FmhaKernel:
|
||||
if self.num_tmem_alloc_cols > 512:
|
||||
print(f"⚠️ TMEM BUDGET: {self.num_tmem_alloc_cols} cols (hd={hd})")
|
||||
|
||||
# P TMEM alias (PV A-operand viewed as TMEM for partition mapping)
|
||||
self.p_tmem_s = tStS # reuses QK C-fragment TMEM layout for P partition
|
||||
# P TMEM layout (PV A-operand SMEM layout — used to alias QK C-fragment in TMEM)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
|
||||
# TMA bytes
|
||||
cta = cute.size(qk_mma.thr_id.shape)
|
||||
@@ -143,7 +143,7 @@ class FmhaKernel:
|
||||
self._kernel(
|
||||
qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC,
|
||||
self.cluster_layout_vmnk,
|
||||
self.q_smem_s, self.k_smem_s, self.v_smem_s, self.p_smem_s, self.c_smem_s,
|
||||
self.q_smem_s, self.k_smem_s, self.v_smem_s, self.p_smem_s, self.p_tmem_s, self.c_smem_s,
|
||||
self.epi_tile,
|
||||
).launch(
|
||||
grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream,
|
||||
@@ -152,7 +152,7 @@ class FmhaKernel:
|
||||
@cute.kernel
|
||||
def _kernel(
|
||||
self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC,
|
||||
cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_smem_s, c_smem_s, epi_tile,
|
||||
cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_smem_s, p_tmem_s, c_smem_s, epi_tile,
|
||||
):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
@@ -253,7 +253,7 @@ class FmhaKernel:
|
||||
|
||||
# ── TMEM-P path: PV A-operand from TMEM ──
|
||||
if not use_smem_p:
|
||||
tP = cute.make_tensor(tStS.iterator, self.p_tmem_s.outer)
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
|
||||
Reference in New Issue
Block a user