D1.3: Try make_tiled_copy_C(qk_mma) for SMEM-P copy - zero-fill source for compile test
This commit is contained in:
@@ -267,20 +267,28 @@ class FmhaKernel:
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# P SMEM copy atoms: SMEM-P
|
||||
# Uses get_smem_store_op + make_tiled_copy_D from CUTLASS blackwell_helpers.
|
||||
# This creates a SMEM store copy using the same thread partition as the TMEM load,
|
||||
# so the same threads that compute P (softmax warps) can write P to sP directly.
|
||||
# The MMA warp then reads P from sP via pv_mma.make_fragment_A(sP).
|
||||
# Must define unconditionally (CuTeDSL scoping: compile both branches).
|
||||
# For TMEM-P (use_smem_p=False), these are allocated but unused (dead-code-eliminated).
|
||||
_p_smem_store_atom = get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_tmem_load
|
||||
# Approach: use make_tiled_copy_C(qk_mma) to create a copy that writes
|
||||
# from QK C-fragment register layout to SMEM. The softmax threads have P values
|
||||
# in registers after computing softmax. We write these to sP so the MMA warp
|
||||
# can read them via pv_mma.make_fragment_A(sP).
|
||||
# Must define unconditionally (CuTeDSL scoping).
|
||||
_smem_copy_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
self.q_dtype,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
_tiled_smem_store_p = cute.make_tiled_copy_D(_p_smem_store_atom, tiled_tmem_load)
|
||||
_thr_smem_store_p = _tiled_smem_store_p.get_slice(sfw_idx)
|
||||
_tSMEM_STOREsP = _thr_smem_store_p.partition_D(sP)
|
||||
_tSMEM_STORErS = _thr_smem_store_p.partition_S(tTMEM_LOADtS)
|
||||
_rP_smem = cute.make_rmem_tensor(_tSMEM_STORErS.shape, self.q_dtype)
|
||||
_tiled_smem_copy_C = cute.make_tiled_copy_C(_smem_copy_atom, qk_mma)
|
||||
_thr_smem_copy_C = _tiled_smem_copy_C.get_slice(sfw_idx)
|
||||
# Destination: sP partitioned by QK C-fragment thread mapping
|
||||
_sP_2d = cute.group_modes(sP, 0, 3)
|
||||
_tSMEM_COPYsP = _thr_smem_copy_C.partition_D(_sP_2d)
|
||||
# Source: QK C-fragment register layout (same as what make_fragment_C produces)
|
||||
# The softmax has P in rP_bf16 (TME load layout). We need a source tensor
|
||||
# in QK C-fragment register layout. Create a register tensor with the right shape.
|
||||
_qk_C_reg = qk_thr.make_fragment_C(qk_as) # QK C-fragment register fragment
|
||||
_qk_C_2d = cute.group_modes(_qk_C_reg, 0, 2) # (M*K, STAGE)
|
||||
_tSMEM_COPYrS = _thr_smem_copy_C.partition_S(_qk_C_2d)
|
||||
_rP_smem_src = cute.make_rmem_tensor(_tSMEM_COPYrS.shape, self.q_dtype)
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
@@ -353,13 +361,16 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: store P to SMEM via make_tiled_copy_D
|
||||
# The P values are in rP_bf16 (BF16 view of the FP32 register bridge).
|
||||
# Copy the BF16 P values to the SMEM store source registers.
|
||||
for j in cutlass.range(cute.size(rP_bf16), vectorize=True):
|
||||
_rP_smem[j] = rP_bf16[j]
|
||||
# Write P to sP (PV A-operand SMEM layout)
|
||||
cute.copy(_tiled_smem_store_p, _rP_smem, _tSMEM_STOREsP)
|
||||
# SMEM-P: store P to SMEM via make_tiled_copy_C(qk_mma)
|
||||
# The P values are in rP_bf16 (TME load layout). We need to
|
||||
# rearrange them into the QK C-fragment register layout for the copy.
|
||||
# Copy rP_bf16 values into _rP_smem_src (QK C-fragment register layout).
|
||||
# This is a register-to-register rearrangement.
|
||||
# TODO: This rearrangement may be avoidable if we can directly use
|
||||
# the TMEM load layout as source. For now, zero-fill and copy.
|
||||
for j in cutlass.range(cute.size(_rP_smem_src), vectorize=True):
|
||||
_rP_smem_src[j] = self.q_dtype(0)
|
||||
cute.copy(_tiled_smem_copy_C, _rP_smem_src, _tSMEM_COPYsP)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
if kt > 0:
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
|
||||
Reference in New Issue
Block a user