D1: Add sP shape debug print

This commit is contained in:
2026-05-24 03:46:27 +00:00
parent 0f52c3453c
commit ebde1d67fd

View File

@@ -288,6 +288,8 @@ class FmhaKernel:
# sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
# This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192))
_sP_nostage = sP[(None, None, None, 0)] # remove stage dim
# Debug: print sP shape at trace time
print(f"SMEM-P: sP shape={cute.shape(sP)}, sP_nostage shape={cute.shape(_sP_nostage)}")
row_max = -Float32.inf
row_sum = Float32(0.0)