Fix tOrP0 indexing: 3-dim slice (None,None,kb) not 4-dim

This commit is contained in:
2026-05-23 05:09:19 +00:00
parent 0b277f4199
commit 300482e40a

View File

@@ -310,7 +310,7 @@ class FmhaKernel:
if not use_smem_p:
# TMEM-P: P from TMEM
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None, None, kb, 0)], tCrV[(None, None, kb, kvh.index)], tOtO0)
cute.gemm(pv_mma, tOtO0, tOrP0[(None, None, kb)], tCrV[(None, None, kb, kvh.index)], tOtO0)
else:
# SMEM-P: P from SMEM
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):