diff --git a/tests/unit/test_d1_3_cotiled.py b/tests/unit/test_d1_3_cotiled.py index cfef2088..2e22c92e 100644 --- a/tests/unit/test_d1_3_cotiled.py +++ b/tests/unit/test_d1_3_cotiled.py @@ -82,9 +82,10 @@ class SmemPDiag: s_cols = self.qk_mma_tiler[1] o_cols = find_tmem_tensor_col_offset(tOtO) total = max(s_cols, o_cols) - self.num_tmem_alloc_cols = 1 - while self.num_tmem_alloc_cols < total: - self.num_tmem_alloc_cols *= 2 + _n = 1 + while _n < total: + _n *= 2 + self.num_tmem_alloc_cols = _n self.tOrP0_offset = 0 # SMEM-P # Build SMEM layouts