From 4f6853e1ae91b7aa29f08089edb1f9e7b3bc0d3c Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 16:57:31 +0000 Subject: [PATCH] FIX: only slice GMEM tensors (SMEM already 2D from tma_partition) --- tests/unit/test_fmha_v3_diag.py | 4 +--- tests/unit/test_fmha_v3_stage_c_full.py | 16 +++++++--------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/tests/unit/test_fmha_v3_diag.py b/tests/unit/test_fmha_v3_diag.py index 98d580c8..a4a5705f 100644 --- a/tests/unit/test_fmha_v3_diag.py +++ b/tests/unit/test_fmha_v3_diag.py @@ -137,9 +137,7 @@ class FmhaV3Diag: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - tAsQ = tAsQ[(None,0,None,0)]; tAgQ = tAgQ[(None,0,None,0)] - tBsK = tBsK[(None,None,0,0)]; tBgK = tBgK[(None,None,0,0)] - tVsV = tVsV[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,0,0)]; tVgV = tVgV[(None,0,None,0)] tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index baf59749..88b63547 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -167,15 +167,13 @@ class FmhaV3StageC: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # Q: GMEM iter at mode 1 (from QK MMA A-partition, layout Q,D,L) - # K: GMEM iter at mode 1 (from QK MMA B-partition, layout K,D,L) - # V: GMEM iter at mode 2 (from PV MMA B-partition, layout D,K,L) - # Must keep GMEM iter dimension FREE for multi-tile KV loads. - # CUTLASS reference slices: tQgQ[None,None,0,0], tKgK[None,None,0,0], tVgV[None,0,None,0] - # Both GMEM and SMEM sides must be sliced consistently for cute.copy rank matching. - tAsQ = tAsQ[(None,0,None,0)]; tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile only - tBsK = tBsK[(None,None,0,0)]; tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) - tVsV = tVsV[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) + # Only slice GMEM tensors (SMEM tensors from tma_partition are already 2D). + # K from QK MMA: GMEM iter at mode 1 → slice [(None,None,0,0)] keeps modes 0,1 + # V from PV MMA: GMEM iter at mode 2 → slice [(None,0,None,0)] keeps modes 0,2 + # Q: 1 tile only, original slice is fine. + tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode GMEM iter to 0 + tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free for kvh.count + tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free for kvh.count tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV)