Fix TMA shape diag: use ct tensors for LayoutEnum
This commit is contained in:
@@ -30,10 +30,14 @@ def diag():
|
||||
qk_acc_dtype = Float32
|
||||
q_dtype = BFloat16
|
||||
|
||||
a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
|
||||
|
||||
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
|
||||
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
|
||||
v_fmha = cute.make_tensor(
|
||||
v_kernel,
|
||||
mV.iterator,
|
||||
cute.make_layout(
|
||||
(HEAD_DIM, s_k, 1),
|
||||
stride=(1, HEAD_DIM, HEAD_DIM * s_k),
|
||||
|
||||
Reference in New Issue
Block a user