D1.5: Fix SMEM-P - use coordinate-indexed store (same proven pattern)

This commit is contained in:
2026-05-24 03:19:32 +00:00
parent 93e7fe97f7
commit 53bc54ed17

View File

@@ -366,39 +366,19 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: write P to sP using a TiledCopy derived from PV MMA's
# A-operand layout, but with softmax thread mapping.
#
# Strategy: Use the PV MMA's A-operand SMEM layout (which matches sP)
# and create a new TiledCopy with softmax thread/value layout.
# The softmax threads (128 total) each own one row of P.
# Within each row, values are in sP's subtiled format.
#
# We use pv_mma's make_tiled_copy_A to get the copy atom and tiling,
# then override the thread layout for the softmax threads.
_smem_p_tiled_copy = utils.sm100.make_tiled_copy_A(
cute.nvgpu.CopyUniversalOp(),
pv_mma, self.q_dtype,
128, # tiler_mn - matches the P matrix tile size
)
# Get the softmax thread's partition
_thr_smem_p = _smem_p_tiled_copy.get_slice(sfw_idx)
# Create a logical (non-swizzled) view of sP for partitioning
_sP_logical = cute.make_tensor(_sP_nostage.iterator, _sP_nostage.layout)
_tRS_sP = _thr_smem_p.partition_D(_sP_logical)
# Create source register tensor matching the copy's value order
_tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype)
# Fill _tRS_rP from rP_bf16.
# rP_bf16 is in TMEM-load order: ((32,1),4,1,1) with 128 values
# _tRS_rP is in copy value order. We need to map between them.
# For the copy, each thread should own P[thread_row, :].
# rP_bf16[(j0,0),j1,0,0] = P[thread_row, j0+32*j1]
# We need to figure out the copy's value order for our thread.
# PRINT THE SHAPES to understand the mapping
# For now, fill with zeros as a baseline test
for v_idx in cutlass.range(cute.size(_tRS_rP), vectorize=True):
_tRS_rP[v_idx] = BFloat16(0.0)
cute.copy(_smem_p_tiled_copy, _tRS_rP, _tRS_sP)
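# SMEM-P: write P to sP using coordinate-indexed store.
# Uses tTMEM_LOADcS identity tensor to get each element's (m, k) coordinates
# and maps them to sP's swizzled layout: k is split into k0 = k % 16,
# k1 = (k // 16) % 4 and k2 = k // 64, and the element is stored at
# _sP_nostage[(m, k0), 0, (k1, k2)].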
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0]
k_coord = coord[1]
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
cute.arch.fence_proxy("async.shared", space="cta")
if kt > 0:
for i in range(n_corr_tiles):