diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 13f7b69a..7976bc5b 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -171,9 +171,13 @@ class FmhaV3StageC: # CUTLASS reference: tKgK = tKgK_kdl[None, None, 0, 0] (keeps TMA_atom + GMEM_iter) # tVgV = tVgV_dkl[None, 0, None, 0] (keeps TMA_atom + GMEM_iter at mode 2) # SMEM tensors from tma_partition are already 2D — don't re-slice them. + print(f"DEBUG tBgK shape before slice: {tBgK.shape}") + print(f"DEBUG tVgV shape before slice: {tVgV.shape}") tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode is fine tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free + print(f"DEBUG tBgK shape after slice: {tBgK.shape}") + print(f"DEBUG tVgV shape after slice: {tVgV.shape}") tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV)