diff --git a/tests/unit/test_cotiled_diag.py b/tests/unit/test_cotiled_diag.py index 7350acf4..965ac300 100644 --- a/tests/unit/test_cotiled_diag.py +++ b/tests/unit/test_cotiled_diag.py @@ -30,9 +30,10 @@ def main(): v_fmha = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1) mV = ct.from_dlpack(v_fmha).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_fmha)) - a_major = cute.LayoutEnum.from_tensor(mQ).mma_major_mode() - b_major = cute.LayoutEnum.from_tensor(mK).mma_major_mode() - v_major = cute.LayoutEnum.from_tensor(mV).mma_major_mode() + from cutlass.utils import LayoutEnum + a_major = LayoutEnum.from_tensor(mQ).mma_major_mode() + b_major = LayoutEnum.from_tensor(mK).mma_major_mode() + v_major = LayoutEnum.from_tensor(mV).mma_major_mode() qk_mma = utils.sm100.make_trivial_tiled_mma( BFloat16, BFloat16, a_major, b_major, Float32,