fix: fence_proxy not fence

This commit is contained in:
2026-05-24 02:44:20 +00:00
parent cd6d81fc4b
commit 7357b1a866

View File

@@ -366,21 +366,52 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: write P to sP using make_cotiled_copy.
# Build atom_layout_tv: (tid, vid) -> sP_addr.
# Each softmax thread (128 total) owns one row of the 128×128 P matrix.
# Within a row, 128 values decompose as (k0, k1, k2) = (16, 4, 2)
# matching sP's subtiled layout strides (1, 16, 8192).
# SMEM-P: write P to sP using coordinate-indexed store.
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0]
k_coord = coord[1]
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
# SMEM-P: write P to sP using make_tiled_copy_tv.
# Strategy: Create a R→S copy whose TV layout maps softmax
# thread/value ownership to sP's address space.
#
# Key assumption: 128 softmax threads each own one row of the
# 128×128 P matrix. Within each row, values are in sP's
# subtiled format: (16, 4, 2) with strides (1, 16, 8192).
# Thread stride = 64 (one row in sP's layout).
#
# We create the source register tensor in the copy's value order
# and fill it from rP_bf16 during the softmax loop.
_smem_p_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.q_dtype,
num_bits_per_copy=16,
)
_smem_p_thr_layout = cute.make_layout((128,), stride=(1,))
_smem_p_val_layout = cute.make_layout((16, 4, 2), stride=(1, 16, 8192))
_tiled_smem_p = cute.make_tiled_copy_tv(
_smem_p_atom, _smem_p_thr_layout, _smem_p_val_layout,
)
_thr_smem_p = _tiled_smem_p.get_slice(sfw_idx)
# Create a logical (non-swizzled) view of sP for partitioning
_sP_logical = cute.make_tensor(_sP_nostage.iterator, cute.make_layout(
(128, 16, 4, 2), stride=(64, 1, 16, 8192)
))
_tRS_sP = _thr_smem_p.partition_D(_sP_logical)
# Create source register tensor in copy's value order
_tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype)
# Fill _tRS_rP from rP_bf16 using coordinate-based mapping
# Each thread owns row sfw_idx of the P matrix.
# rP_bf16 is in TMEM-load register order: ((32,1),4,1,1)
# _tRS_rP is in copy value order: (16, 4, 2) per thread
# For thread i, rP_bf16[(j0,0),j1,0,0] corresponds to P[i, j0+32*j1]
# And _tRS_rP[k0,k1,k2] corresponds to P[i, k0+16*k1+64*k2]
# So k0+16*k1+64*k2 = j0+32*j1, which means:
# j0 = (k0+16*k1+64*k2) % 32
# j1 = (k0+16*k1+64*k2) // 32
for k0 in range(16):
for k1 in range(4):
for k2 in range(2):
flat_k = k0 + 16 * k1 + 64 * k2 # k index in [0,128)
j0 = flat_k % 32
j1 = flat_k // 32
_tRS_rP[k0, k1, k2] = rP_bf16[(j0, 0), j1, 0, 0]
cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP)
cute.arch.fence_proxy("async.shared", space="cta")
if kt > 0:
for i in range(n_corr_tiles):