D1.3: Fix coord extraction - identity tensor stores (m,k) pairs as values

This commit is contained in:
2026-05-23 23:21:15 +00:00
parent a7171fa5e1
commit e0a11e32f8

View File

@@ -352,13 +352,16 @@ class FmhaKernel:
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: write P to sP using coordinate-indexed store.
# tTMEM_LOADcS shape: ((32,1),4,1,1) with layout ((32,1),4,1,1)
# First mode is (32,1) — 32 m-coordinates per fragment, 1 k-slice.
# So indexing: tTMEM_LOADcS[(j0, 0), j1, 0, 0] gives (m, k).
# tTMEM_LOADcS contains (m, k) coordinates from identity tensor.
# Shape: ((32,1),4,1,1) — indexed with 4 indices.
# Each element is an (m, k) coordinate pair.
# Extract m with .load()[0] and k with .load()[1],
# or use indexing tTMEM_LOADcS[...].value[0/1].
for j0 in range(32):
for j1 in range(4):
m_coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0, 0]
k_coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0, 1]
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0]
k_coord = coord[1]
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64