fix: LayoutEnum import in cotiled diag test
This commit is contained in:
@@ -30,9 +30,10 @@ def main():
|
||||
v_fmha = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
|
||||
mV = ct.from_dlpack(v_fmha).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_fmha))
|
||||
|
||||
a_major = cute.LayoutEnum.from_tensor(mQ).mma_major_mode()
|
||||
b_major = cute.LayoutEnum.from_tensor(mK).mma_major_mode()
|
||||
v_major = cute.LayoutEnum.from_tensor(mV).mma_major_mode()
|
||||
from cutlass.utils import LayoutEnum
|
||||
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
|
||||
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
|
||||
v_major = LayoutEnum.from_tensor(mV).mma_major_mode()
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, a_major, b_major, Float32,
|
||||
|
||||
Reference in New Issue
Block a user