DEBUG: print tBgK/tVgV layout to check strides
This commit is contained in:
@@ -181,6 +181,10 @@ class FmhaV3StageCMulti:
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
# DEBUG: print shapes and strides to verify kv_tiles dimension has nonzero stride
|
||||
print(f"tBgK pre-sliced shape: {cute.shape(tBgK)} layout: {tBgK.layout}")
|
||||
print(f"tVgV pre-sliced shape: {cute.shape(tVgV)} layout: {tVgV.layout}")
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user