diff --git a/tests/unit/test_d1_diag.py b/tests/unit/test_d1_diag.py index b7c548db..fbc07d1b 100644 --- a/tests/unit/test_d1_diag.py +++ b/tests/unit/test_d1_diag.py @@ -27,8 +27,8 @@ mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) -a_major = LayoutEnum.from_tensor(q).mma_major_mode() -b_major = LayoutEnum.from_tensor(k).mma_major_mode() +a_major = LayoutEnum.from_tensor(mQ).mma_major_mode() +b_major = LayoutEnum.from_tensor(mK).mma_major_mode() pv_n_tile = min(hd, 256) v_fmha = cute.make_tensor(