Fix tOrP0 indexing: slice with a 3-element coordinate (None, None, kb) instead of the 4-element (None, None, kb, 0)
This commit is contained in:
@@ -310,7 +310,7 @@ class FmhaKernel:
|
||||
if not use_smem_p:
|
||||
# TMEM-P: P from TMEM
|
||||
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None, None, kb, 0)], tCrV[(None, None, kb, kvh.index)], tOtO0)
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None, None, kb)], tCrV[(None, None, kb, kvh.index)], tOtO0)
|
||||
else:
|
||||
# SMEM-P: P from SMEM
|
||||
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
|
||||
|
||||
Reference in New Issue
Block a user