diff --git a/tests/unit/test_tma_shapes.py b/tests/unit/test_tma_shapes.py index ec3430fb..16ab9794 100644 --- a/tests/unit/test_tma_shapes.py +++ b/tests/unit/test_tma_shapes.py @@ -30,10 +30,14 @@ def diag(): qk_acc_dtype = Float32 q_dtype = BFloat16 - a_major = LayoutEnum.from_tensor(q).mma_major_mode() - b_major = LayoutEnum.from_tensor(k).mma_major_mode() + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + + a_major = LayoutEnum.from_tensor(mQ).mma_major_mode() + b_major = LayoutEnum.from_tensor(mK).mma_major_mode() v_fmha = cute.make_tensor( - v_kernel, + mV.iterator, cute.make_layout( (HEAD_DIM, s_k, 1), stride=(1, HEAD_DIM, HEAD_DIM * s_k),