D1.5: Fix SMEM-P - use coordinate-indexed store (same proven pattern)
This commit is contained in:
@@ -366,39 +366,19 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: write P to sP using a TiledCopy derived from PV MMA's
|
||||
# A-operand layout, but with softmax thread mapping.
|
||||
#
|
||||
# Strategy: Use the PV MMA's A-operand SMEM layout (which matches sP)
|
||||
# and create a new TiledCopy with softmax thread/value layout.
|
||||
# The softmax threads (128 total) each own one row of P.
|
||||
# Within each row, values are in sP's subtiled format.
|
||||
#
|
||||
# We use pv_mma's make_tiled_copy_A to get the copy atom and tiling,
|
||||
# then override the thread layout for the softmax threads.
|
||||
_smem_p_tiled_copy = utils.sm100.make_tiled_copy_A(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
pv_mma, self.q_dtype,
|
||||
128, # tiler_mn - matches the P matrix tile size
|
||||
)
|
||||
# Get the softmax thread's partition
|
||||
_thr_smem_p = _smem_p_tiled_copy.get_slice(sfw_idx)
|
||||
# Create a logical (non-swizzled) view of sP for partitioning
|
||||
_sP_logical = cute.make_tensor(_sP_nostage.iterator, _sP_nostage.layout)
|
||||
_tRS_sP = _thr_smem_p.partition_D(_sP_logical)
|
||||
# Create source register tensor matching the copy's value order
|
||||
_tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype)
|
||||
# Fill _tRS_rP from rP_bf16.
|
||||
# rP_bf16 is in TMEM-load order: ((32,1),4,1,1) with 128 values
|
||||
# _tRS_rP is in copy value order. We need to map between them.
|
||||
# For the copy, each thread should own P[thread_row, :].
|
||||
# rP_bf16[(j0,0),j1,0,0] = P[thread_row, j0+32*j1]
|
||||
# We need to figure out the copy's value order for our thread.
|
||||
# PRINT THE SHAPES to understand the mapping
|
||||
# For now, fill with zeros as a baseline test
|
||||
for v_idx in cutlass.range(cute.size(_tRS_rP), vectorize=True):
|
||||
_tRS_rP[v_idx] = BFloat16(0.0)
|
||||
cute.copy(_smem_p_tiled_copy, _tRS_rP, _tRS_sP)
|
||||
# SMEM-P: write P to sP using coordinate-indexed store.
|
||||
# Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates
|
||||
# and maps them to sP's swizzled layout.
|
||||
for j0 in range(32):
|
||||
for j1 in range(4):
|
||||
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
|
||||
m_coord = coord[0]
|
||||
k_coord = coord[1]
|
||||
k0 = k_coord % 16
|
||||
k1 = (k_coord // 16) % 4
|
||||
k2 = k_coord // 64
|
||||
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
if kt > 0:
|
||||
for i in range(n_corr_tiles):
|
||||
|
||||
Reference in New Issue
Block a user