D1.3: Try make_tiled_copy_C(qk_mma) for SMEM-P copy - zero-fill source for compile test

This commit is contained in:
2026-05-23 22:29:10 +00:00
parent bafcfa658f
commit f1341ad76e

View File

@@ -267,20 +267,28 @@ class FmhaKernel:
tTMEM_STOREcP = thr_store.partition_S(tScP)
# P SMEM copy atoms: SMEM-P
# Uses get_smem_store_op + make_tiled_copy_D from CUTLASS blackwell_helpers.
# This creates a SMEM store copy using the same thread partition as the TMEM load,
# so the same threads that compute P (softmax warps) can write P to sP directly.
# The MMA warp then reads P from sP via pv_mma.make_fragment_A(sP).
# Must define unconditionally (CuTeDSL scoping: compile both branches).
# For TMEM-P (use_smem_p=False), these are allocated but unused (dead-code-eliminated).
_p_smem_store_atom = get_smem_store_op(
self.c_layout, self.c_dtype, self.acc_dtype, tiled_tmem_load
# Approach: use make_tiled_copy_C(qk_mma) to create a copy that writes
# from QK C-fragment register layout to SMEM. The softmax threads have P values
# in registers after computing softmax. We write these to sP so the MMA warp
# can read them via pv_mma.make_fragment_A(sP).
# Must define unconditionally (CuTeDSL scoping).
_smem_copy_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.q_dtype,
num_bits_per_copy=128,
)
_tiled_smem_store_p = cute.make_tiled_copy_D(_p_smem_store_atom, tiled_tmem_load)
_thr_smem_store_p = _tiled_smem_store_p.get_slice(sfw_idx)
_tSMEM_STOREsP = _thr_smem_store_p.partition_D(sP)
_tSMEM_STORErS = _thr_smem_store_p.partition_S(tTMEM_LOADtS)
_rP_smem = cute.make_rmem_tensor(_tSMEM_STORErS.shape, self.q_dtype)
_tiled_smem_copy_C = cute.make_tiled_copy_C(_smem_copy_atom, qk_mma)
_thr_smem_copy_C = _tiled_smem_copy_C.get_slice(sfw_idx)
# Destination: sP partitioned by QK C-fragment thread mapping
_sP_2d = cute.group_modes(sP, 0, 3)
_tSMEM_COPYsP = _thr_smem_copy_C.partition_D(_sP_2d)
# Source: QK C-fragment register layout (same as what make_fragment_C produces)
# The softmax has P in rP_bf16 (TMEM load layout). We need a source tensor
# in QK C-fragment register layout. Create a register tensor with the right shape.
_qk_C_reg = qk_thr.make_fragment_C(qk_as) # QK C-fragment register fragment
_qk_C_2d = cute.group_modes(_qk_C_reg, 0, 2) # (M*K, STAGE)
_tSMEM_COPYrS = _thr_smem_copy_C.partition_S(_qk_C_2d)
_rP_smem_src = cute.make_rmem_tensor(_tSMEM_COPYrS.shape, self.q_dtype)
row_max = -Float32.inf
row_sum = Float32(0.0)
@@ -353,13 +361,16 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: store P to SMEM via make_tiled_copy_D
# The P values are in rP_bf16 (BF16 view of the FP32 register bridge).
# Copy the BF16 P values to the SMEM store source registers.
for j in cutlass.range(cute.size(rP_bf16), vectorize=True):
_rP_smem[j] = rP_bf16[j]
# Write P to sP (PV A-operand SMEM layout)
cute.copy(_tiled_smem_store_p, _rP_smem, _tSMEM_STOREsP)
# SMEM-P: store P to SMEM via make_tiled_copy_C(qk_mma)
# The P values are in rP_bf16 (TMEM load layout). We need to
# rearrange them into the QK C-fragment register layout for the copy.
# Intended: copy rP_bf16 into _rP_smem_src (QK C-fragment register layout),
# i.e. a register-to-register rearrangement. That copy is NOT implemented yet.
# TODO: This rearrangement may be avoidable if we can directly use the TMEM
# load layout as source. For now the source is zero-filled (compile test only).
for j in cutlass.range(cute.size(_rP_smem_src), vectorize=True):
_rP_smem_src[j] = self.q_dtype(0)
cute.copy(_tiled_smem_copy_C, _rP_smem_src, _tSMEM_COPYsP)
cute.arch.fence_proxy("async.shared", space="cta")
if kt > 0:
tTMrO = cute.make_rmem_tensor(