D1: Add more debug prints (QK/PV mode2 sizes)
This commit is contained in:
@@ -288,8 +288,12 @@ class FmhaKernel:
|
||||
# sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
|
||||
# This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192))
|
||||
_sP_nostage = sP[(None, None, None, 0)] # remove stage dim
|
||||
# Debug: print sP shape at trace time
|
||||
# Debug: print key shapes at trace time
|
||||
print(f"SMEM-P: sP shape={cute.shape(sP)}, sP_nostage shape={cute.shape(_sP_nostage)}")
|
||||
print(f"QK: tCrQ mode2 size={cute.size(tCrQ, mode=[2])}, tCrK mode2 size={cute.size(tCrK, mode=[2])}")
|
||||
print(f"QK: tOrP0 mode2 size={cute.size(tOrP0, mode=[2])}")
|
||||
print(f"PV: tCrP mode2 size={cute.size(tCrP, mode=[2])}")
|
||||
print(f"PV: pv_n_tile={self.pv_n_tile}, n_pv_tiles={self.n_pv_tiles}")
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
|
||||
Reference in New Issue
Block a user