D1.3: Add debug prints for SMEM-P coordinate mapping

This commit is contained in:
2026-05-23 23:24:02 +00:00
parent de869c01c8
commit fca9652719

View File

@@ -352,14 +352,14 @@ class FmhaKernel:
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: write P to sP using coordinate-indexed store.
# tTMEM_LOADcS contains (m, k) coordinates from identity tensor.
# Each element is an (m, k) coordinate pair.
# rP_bf16 has the same shape/layout as tTMEM_LOADcS (BF16 view of FP32 registers).
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0]
k_coord = coord[1]
# DEBUG: print first 8 coords from thread 0
if sfw_idx == 0 and kt == 0 and j0 < 2 and j1 < 2:
print(f"[SMEM-P] j0={j0} j1={j1} m={m_coord} k={k_coord} P={rP_bf16[(j0, 0), j1, 0, 0]}")
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64