diff --git a/tests/test_stage_b_v29.py b/tests/test_stage_b_v29.py index 5dec8fe9..bafc1283 100644 --- a/tests/test_stage_b_v29.py +++ b/tests/test_stage_b_v29.py @@ -31,7 +31,7 @@ class StageBIdentitySoftmax: def _setup(self, qk_mma, pv_mma): qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) - self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[1], self.qk_mma_tiler[1]) self.mma_tiler = self.qk_mma_tiler self.cta_tile_shape_mnk = ( self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), @@ -69,7 +69,7 @@ class StageBIdentitySoftmax: a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) self.num_tma_load_bytes = ( - cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + cute.size_in_bytes(self.b_dtype, cute.slice_(self.v_smem_s, (None, None, None, 0))) ) * cute.size(qk_mma.thr_id.shape) @cute.jit @@ -240,12 +240,11 @@ class StageBIdentitySoftmax: s0_handle.commit() s0_handle = mma_si_prod.acquire_and_advance() - pv_mma.set(tcgen05.Field.ACCUMULATE, False) + pv_mma.set(tcgen05.Field.ACCUMULATE, True) tCrV_s = tCrV[(None, None, None, 0)] nblk_pv = cute.size(tOrP0, mode=[2]) for kb in cutlass.range(nblk_pv, unroll_full=True): cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0) - pv_mma.set(tcgen05.Field.ACCUMULATE, True) acc_pipe.producer_commit(acc_prod_st) acc_prod_st.advance()