DEBUG: print tBgK/tVgV shapes before/after slice

This commit is contained in:
2026-05-22 17:28:45 +00:00
parent 62bbc5ce42
commit 46ded465da

View File

@@ -171,9 +171,13 @@ class FmhaV3StageC:
# CUTLASS reference: tKgK = tKgK_kdl[None, None, 0, 0] (keeps TMA_atom + GMEM_iter)
# tVgV = tVgV_dkl[None, 0, None, 0] (keeps TMA_atom + GMEM_iter at mode 2)
# SMEM tensors from tma_partition are already 2D — don't re-slice them.
print(f"DEBUG tBgK shape before slice: {tBgK.shape}")
print(f"DEBUG tVgV shape before slice: {tVgV.shape}")
tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode is fine
tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free
tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free
print(f"DEBUG tBgK shape after slice: {tBgK.shape}")
print(f"DEBUG tVgV shape after slice: {tVgV.shape}")
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)