DEBUG: print tBgK/tVgV shapes before/after slice
This commit is contained in:
@@ -171,9 +171,13 @@ class FmhaV3StageC:
|
||||
# CUTLASS reference: tKgK = tKgK_kdl[None, None, 0, 0] (keeps TMA_atom + GMEM_iter)
|
||||
# tVgV = tVgV_dkl[None, 0, None, 0] (keeps TMA_atom + GMEM_iter at mode 2)
|
||||
# SMEM tensors from tma_partition are already 2D — don't re-slice them.
|
||||
print(f"DEBUG tBgK shape before slice: {tBgK.shape}")
|
||||
print(f"DEBUG tVgV shape before slice: {tVgV.shape}")
|
||||
tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode is fine
|
||||
tBgK = tBgK[(None,None,0,0)] # K: keep mode 1 (GMEM iter) free
|
||||
tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free
|
||||
print(f"DEBUG tBgK shape after slice: {tBgK.shape}")
|
||||
print(f"DEBUG tVgV shape after slice: {tVgV.shape}")
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
Reference in New Issue
Block a user