fix: tTMrO scoping + restore SMEM-P coordinate write
This commit is contained in:
@@ -366,13 +366,22 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: TEMPORARILY zero-fill sP (debugging deadlock).
|
||||
# The coordinate-indexed write causes a deadlock at hd=256.
|
||||
# TODO: Fix the SMEM-P write path.
|
||||
# SMEM-P: write P to sP using coordinate-indexed store.
|
||||
# Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates.
|
||||
for j0 in range(32):
|
||||
for j1 in range(4):
|
||||
_sP_nostage[(j0, j1), 0, (0, 0)] = BFloat16(0.0)
|
||||
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
|
||||
m_coord = coord[0]
|
||||
k_coord = coord[1]
|
||||
k0 = k_coord % 16
|
||||
k1 = (k_coord // 16) % 4
|
||||
k2 = k_coord // 64
|
||||
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
# O rescale register tensor (defined unconditionally for CuTeDSL scoping)
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
if kt > 0:
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
|
||||
Reference in New Issue
Block a user