D1.3: Fix coordinate indexing - tTMEM_LOADcS first mode is (32,1) nested tuple

This commit is contained in:
2026-05-23 23:20:12 +00:00
parent 4b8970d83c
commit f74fd75054

View File

@@ -352,19 +352,17 @@ class FmhaKernel:
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: write P to sP using coordinate-indexed store.
# Each thread knows its (m, k) from tTMEM_LOADcS.
# Index sP at ((m, k%16), 0, ((k//16)%4, k//64), 0).
# CuTeDSL tensor indexing handles the swizzle automatically.
for j0 in range(cute.size(tTMEM_LOADcS, mode=[0])):
for j1 in range(cute.size(tTMEM_LOADcS, mode=[1])):
m_coord = tTMEM_LOADcS[j0, j1, 0, 0, 0]
k_coord = tTMEM_LOADcS[j0, j1, 0, 0, 1]
# Decompose k into sP's sub-coordinates
# tTMEM_LOADcS shape: ((32,1),4,1,1) with layout ((32,1),4,1,1)
# First mode is (32,1) — 32 m-coordinates per fragment, 1 k-slice.
# So indexing: tTMEM_LOADcS[(j0, 0), j1, 0, 0] gives (m, k).
for j0 in range(32):
for j1 in range(4):
m_coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0, 0]
k_coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0, 1]
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64
# Write P value to sP (swizzle handled by tensor layout)
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[j0, j1, 0, 0]
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
if kt > 0:
tTMrO = cute.make_rmem_tensor(