DIAG OUTPUT (n=256, inside @cute.kernel):
  tAgQ: (((64,128),1), Int32(?), Int32(?), Int32(?))  (4 modes)
  tBgK: (((64,128),1), Int32(?), Int32(?), Int32(?))  (4 modes)
  tVgV: (((64,128),1), 1, 1, 1)  (4 modes)

After slicing with (None,0,None,0), modes 0 and 2 stay free, so each tensor becomes 2D:
  tAgQ: (((64,128),1), Int32(?))
  tBgK: (((64,128),1), Int32(?))
  tVgV: (((64,128),1), 1)

Then [None, kt] indexes the surviving mode 1 (originally mode 2, the KV-tile index). Q is indexed as tAgQ[(None, Int32(0))], because Q has a single tile and its coordinate is always 0.

Removed the diag prints from test_fmha_v3.py.
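A minimal pure-Python sketch of the shape bookkeeping above, assuming CuTe's usual slicing rule: a None coordinate keeps a mode, an integer coordinate drops it. slice_shape is a hypothetical helper, not part of the CuTe DSL, and it models only shapes, not data; the "?" strings stand in for the dynamic Int32(?) extents.

```python
# Hypothetical helper (not a CuTe DSL API) modeling CuTe's slicing rule on a
# tuple of top-level modes: None keeps the mode, an integer removes it.
def slice_shape(shape, coord):
    assert len(shape) == len(coord), "need one coordinate per top-level mode"
    return tuple(mode for mode, c in zip(shape, coord) if c is None)

# Top-level modes from the DIAG output; "?" stands in for a dynamic Int32 extent.
tAgQ = (((64, 128), 1), "?", "?", "?")
tVgV = (((64, 128), 1), 1, 1, 1)

# (None, 0, None, 0) keeps modes 0 and 2, so each partition becomes 2D.
q2d = slice_shape(tAgQ, (None, 0, None, 0))
v2d = slice_shape(tVgV, (None, 0, None, 0))
print(q2d)  # (((64, 128), 1), '?')
print(v2d)  # (((64, 128), 1), 1)

# A follow-up [None, kt] fixes the surviving mode 1 (originally mode 2, the
# KV-tile index), leaving only the per-thread fragment mode.
kt = 0
frag = slice_shape(q2d, (None, kt))
print(frag)  # (((64, 128), 1),)
```

The point of the model is that the integer positions in (None,0,None,0) disappear from the result, which is why the later [None, kt] has only two coordinates and why its second entry lands on what was originally mode 2.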