diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 87ff4ce6..67de010b 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -310,7 +310,7 @@ class FmhaKernel: if not use_smem_p: # TMEM-P: P from TMEM for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True): - cute.gemm(pv_mma, tOtO0, tOrP0[(None, None, kb, 0)], tCrV[(None, None, kb, kvh.index)], tOtO0) + cute.gemm(pv_mma, tOtO0, tOrP0[(None, None, kb)], tCrV[(None, None, kb, kvh.index)], tOtO0) else: # SMEM-P: P from SMEM for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):