diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 7976bc5b..04526b9a 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -174,7 +174,7 @@ class FmhaV3StageC: print(f"DEBUG tBgK shape before slice: {tBgK.shape}") print(f"DEBUG tVgV shape before slice: {tVgV.shape}") tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode is fine - tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free + tBgK = tBgK[(None,0,None,0)] # K: try keeping mode 2 free (like V) tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free print(f"DEBUG tBgK shape after slice: {tBgK.shape}") print(f"DEBUG tVgV shape after slice: {tVgV.shape}")