From bb92af5b0cd7e44946366a99e39a04bb3cb9a286 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 21:21:23 +0000 Subject: [PATCH] =?UTF-8?q?FIX:=20Use=20full=208D=20indexing=20for=20tBgK/?= =?UTF-8?q?tVgV=20=E2=80=94=20mode=204=20is=20the=20GMEM=20tile=20dim?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_fmha_v3_stage_c.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index dac7347d..f8549eff 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -179,7 +179,12 @@ class FmhaV3StageCMulti: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,0,0)]; tVgV = tVgV[(None,0,None,0)] + # TMA source tensor slices: keep the GMEM tile dimension (mode 4) free + # tBgK shape: (1, 1, 1, 1, 2, 1, 1, 1) — 8 modes, mode 4 = kv_tiles + # tVgV shape: (1, 1, 1, 1, 2, 1, 1, 1) — 8 modes, mode 4 = kv_tiles + tAgQ = tAgQ[(None,0,None,0)] + tBgK = tBgK[(None,None,None,None,None,None,None,None)] # No-op, use full indexing in copy + tVgV = tVgV[(None,None,None,None,None,None,None,None)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV)