feat: SMEM-P make_tiled_copy_C + zero-fill dest tensor
This commit is contained in:
@@ -366,11 +366,14 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: write P to sP using TiledCopy derived from QK MMA.
|
||||
# SMEM-P: write P to sP using make_tiled_copy_C(qk_mma) with
|
||||
# manual source→destination value mapping.
|
||||
#
|
||||
# make_tiled_copy_C with qk_mma gives threads partitioned by the
|
||||
# QK C-fragment (same as TMEM load). Source: rP_bf16 (registers).
|
||||
# Destination: sP (PV A-operand SMEM layout).
|
||||
# make_tiled_copy_C gives the right thread partition (QK C-fragment)
|
||||
# but the source (rP_bf16) and destination (sP) have different ranks.
|
||||
# Solution: use partition_D to get the sP partition, create a register
|
||||
# tensor matching its shape, fill from rP_bf16 via coordinate mapping,
|
||||
# then copy.
|
||||
_smem_p_store_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
self.q_dtype,
|
||||
@@ -378,8 +381,21 @@ class FmhaKernel:
|
||||
)
|
||||
_tiled_smem_p = cute.make_tiled_copy_C(_smem_p_store_atom, qk_mma)
|
||||
_thr_smem_p = _tiled_smem_p.get_slice(sfw_idx)
|
||||
_tRS_rP = _thr_smem_p.partition_S(rP_bf16)
|
||||
_tRS_sP = _thr_smem_p.partition_D(_sP_nostage)
|
||||
# Create source register tensor matching destination shape
|
||||
_tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype)
|
||||
# Fill _tRS_rP from rP_bf16 using coordinate mapping.
|
||||
# The copy's value layout indexes into the 128×128 P matrix.
|
||||
# We use the TMEM-load coordinate tensor to map each value index
|
||||
# to the corresponding (m, k) and then find the rP_bf16 element.
|
||||
# Since both _tRS_rP and rP_bf16 represent the SAME P values for
|
||||
# this thread, just in different layouts, we can use the coordinate
|
||||
# tensor to establish the mapping.
|
||||
#
|
||||
# For now, zero-fill to test compilation and synchronization.
|
||||
# Once the pipeline runs, we'll fill properly.
|
||||
for v_idx in cutlass.range(cute.size(_tRS_rP), vectorize=True):
|
||||
_tRS_rP[v_idx] = BFloat16(0.0)
|
||||
cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
if kt > 0:
|
||||
|
||||
Reference in New Issue
Block a user