From 27116110ab779c76d6273aa2f66c9ea97cc3d0dd Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 21:21:52 +0000 Subject: [PATCH] Fix identity diag: same 8D TMA indexing fix --- tests/unit/test_fmha_v3_diag.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_fmha_v3_diag.py b/tests/unit/test_fmha_v3_diag.py index 95ff71ca..5251492c 100644 --- a/tests/unit/test_fmha_v3_diag.py +++ b/tests/unit/test_fmha_v3_diag.py @@ -138,7 +138,7 @@ class FmhaV3Diag: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,0,0)]; tVgV = tVgV[(None,0,None,0)] + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,None,None,None,None,None,None)]; tVgV = tVgV[(None,None,None,None,None,None,None,None)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -171,8 +171,8 @@ class FmhaV3Diag: kv_coord = Int32(0 + 0) for kt in cutlass.range(self.n_kv_tiles, unroll=1): kvh = kvp.acquire_and_advance(pk) - cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) - cute.copy(tma_v, tVgV[(None, kv_coord)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_k, tBgK[None, None, None, None, kv_coord, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier) + cute.copy(tma_v, tVgV[None, None, None, None, kv_coord, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier) kv_coord += 1 pk = cutlass.Boolean(1) kvp.tail()