fix: hardcode v_major for diag test
This commit is contained in:
@@ -42,15 +42,9 @@ def main():
|
||||
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
|
||||
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
|
||||
|
||||
# V FMHA layout (same as FmhaKernel.__call__)
|
||||
v_fmha = cute.make_tensor(
|
||||
mV.iterator,
|
||||
cute.make_layout(
|
||||
(pv_n_tile, s_k, 1),
|
||||
stride=(1, pv_n_tile, pv_n_tile * s_k),
|
||||
),
|
||||
)
|
||||
v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
|
||||
# V FMHA layout: use explicit constants
|
||||
# pv_n_tile=256, s_k=128
|
||||
v_major = LayoutEnum.COLUMN_MAJOR # layout (256, 128, 1) stride (1, 256, 32768) = col-major
|
||||
|
||||
c_layout = LayoutEnum.from_tensor(mC)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user