Fix TMA shape diag: use ct tensors for LayoutEnum

This commit is contained in:
2026-05-22 21:55:40 +00:00
parent fd6b1e82d8
commit a476324682

View File

@@ -30,10 +30,14 @@ def diag():
qk_acc_dtype = Float32
q_dtype = BFloat16
a_major = LayoutEnum.from_tensor(q).mma_major_mode()
b_major = LayoutEnum.from_tensor(k).mma_major_mode()
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
v_fmha = cute.make_tensor(
v_kernel,
mV.iterator,
cute.make_layout(
(HEAD_DIM, s_k, 1),
stride=(1, HEAD_DIM, HEAD_DIM * s_k),