From a476324682a85d4d7e911db50baeb75ac48aaf26 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 21:55:40 +0000 Subject: [PATCH] Fix TMA shape diag: use ct tensors for LayoutEnum --- tests/unit/test_tma_shapes.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_tma_shapes.py b/tests/unit/test_tma_shapes.py index ec3430fb..16ab9794 100644 --- a/tests/unit/test_tma_shapes.py +++ b/tests/unit/test_tma_shapes.py @@ -30,10 +30,14 @@ def diag(): qk_acc_dtype = Float32 q_dtype = BFloat16 - a_major = LayoutEnum.from_tensor(q).mma_major_mode() - b_major = LayoutEnum.from_tensor(k).mma_major_mode() + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + + a_major = LayoutEnum.from_tensor(mQ).mma_major_mode() + b_major = LayoutEnum.from_tensor(mK).mma_major_mode() v_fmha = cute.make_tensor( - v_kernel, + mV.iterator, cute.make_layout( (HEAD_DIM, s_k, 1), stride=(1, HEAD_DIM, HEAD_DIM * s_k),