feat: SMEM-P make_tiled_copy_C + zero-fill dest tensor

This commit is contained in:
2026-05-24 03:23:48 +00:00
parent 99b2e12fd8
commit 0de0f20799

View File

@@ -366,11 +366,14 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: write P to sP using TiledCopy derived from QK MMA.
# SMEM-P: write P to sP using make_tiled_copy_C(qk_mma) with
# manual source→destination value mapping.
#
# make_tiled_copy_C with qk_mma gives threads partitioned by the
# QK C-fragment (same as TMEM load). Source: rP_bf16 (registers).
# Destination: sP (PV A-operand SMEM layout).
# make_tiled_copy_C gives the right thread partition (QK C-fragment)
# but the source (rP_bf16) and destination (sP) have different ranks.
# Solution: use partition_D to get the sP partition, create a register
# tensor matching its shape, fill from rP_bf16 via coordinate mapping,
# then copy.
_smem_p_store_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.q_dtype,
@@ -378,8 +381,21 @@ class FmhaKernel:
)
_tiled_smem_p = cute.make_tiled_copy_C(_smem_p_store_atom, qk_mma)
_thr_smem_p = _tiled_smem_p.get_slice(sfw_idx)
_tRS_rP = _thr_smem_p.partition_S(rP_bf16)
_tRS_sP = _thr_smem_p.partition_D(_sP_nostage)
# Create source register tensor matching destination shape
_tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype)
# Fill _tRS_rP from rP_bf16 using coordinate mapping.
# The copy's value layout indexes into the 128×128 P matrix.
# We use the TMEM-load coordinate tensor to map each value index
# to the corresponding (m, k) and then find the rP_bf16 element.
# Since both _tRS_rP and rP_bf16 represent the SAME P values for
# this thread, just in different layouts, we can use the coordinate
# tensor to establish the mapping.
#
# For now, zero-fill to test compilation and synchronization.
# Once the pipeline runs, we'll fill properly.
for v_idx in cutlass.range(cute.size(_tRS_rP), vectorize=True):
_tRS_rP[v_idx] = BFloat16(0.0)
cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP)
cute.arch.fence_proxy("async.shared", space="cta")
if kt > 0: