D1.3: Add debug prints for SMEM-P coordinate mapping
This commit is contained in:
@@ -352,14 +352,14 @@ class FmhaKernel:
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: write P to sP using coordinate-indexed store.
|
||||
# tTMEM_LOADcS contains (m, k) coordinates from identity tensor.
|
||||
# Each element is an (m, k) coordinate pair.
|
||||
# rP_bf16 has the same shape/layout as tTMEM_LOADcS (BF16 view of FP32 registers).
|
||||
for j0 in range(32):
|
||||
for j1 in range(4):
|
||||
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
|
||||
m_coord = coord[0]
|
||||
k_coord = coord[1]
|
||||
# DEBUG: print the first 4 (m, k) coordinate pairs (j0 < 2, j1 < 2) when sfw_idx == 0 and kt == 0
|
||||
if sfw_idx == 0 and kt == 0 and j0 < 2 and j1 < 2:
|
||||
print(f"[SMEM-P] j0={j0} j1={j1} m={m_coord} k={k_coord} P={rP_bf16[(j0, 0), j1, 0, 0]}")
|
||||
k0 = k_coord % 16
|
||||
k1 = (k_coord // 16) % 4
|
||||
k2 = k_coord // 64
|
||||
|
||||
Reference in New Issue
Block a user