D1.3: Define SMEM-P copy atoms unconditionally (CuTeDSL scoping)
This commit is contained in:
@@ -271,26 +271,16 @@ class FmhaKernel:
|
||||
# This creates a SMEM store copy using the same thread partition as the TMEM load,
|
||||
# so the same threads that compute P (softmax warps) can write P to sP directly.
|
||||
# The MMA warp then reads P from sP via pv_mma.make_fragment_A(sP).
|
||||
if self.use_smem_p:
|
||||
_p_smem_store_atom = get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_tmem_load
|
||||
)
|
||||
_tiled_smem_store_p = cute.make_tiled_copy_D(_p_smem_store_atom, tiled_tmem_load)
|
||||
_thr_smem_store_p = _tiled_smem_store_p.get_slice(sfw_idx)
|
||||
# Partition sP for the SMEM store (destination)
|
||||
# sP has shape ((128,16),1,(4,2),1) — 4D. The tiler expects rank >= 3.
|
||||
# Use the full sP tensor (not sP_2d) for proper tiling.
|
||||
_tSMEM_STOREsP = _thr_smem_store_p.partition_D(sP)
|
||||
# Create a source register tensor for the SMEM store
|
||||
# The source layout comes from the TMEM load's source partition (tTMEM_LOADtS),
|
||||
# but we need a BF16 view of P, not the FP32 S view.
|
||||
# The softmax writes P values (BF16) via the register bridge rP_bf16,
|
||||
# which shares layout with rS (tTMEM_LOADrS layout).
|
||||
# make_tiled_copy_D uses the same thread partition, so the source partition
|
||||
# should match the TMEM load source partition.
|
||||
_tSMEM_STORErS = _thr_smem_store_p.partition_S(tTMEM_LOADtS)
|
||||
# rS is FP32, but P is needed in BF16: create a separate register tensor for P (q_dtype) with the same shape.
|
||||
_rP_smem = cute.make_rmem_tensor(_tSMEM_STORErS.shape, self.q_dtype)
|
||||
# These must be defined unconditionally (CuTeDSL scoping: both branches are compiled).
|
||||
# For TMEM-P (use_smem_p=False), these are defined but never used, so they are expected to be dead-code-eliminated.
|
||||
_p_smem_store_atom = get_smem_store_op(
|
||||
self.c_layout, self.c_dtype, self.acc_dtype, tiled_tmem_load
|
||||
)
|
||||
_tiled_smem_store_p = cute.make_tiled_copy_D(_p_smem_store_atom, tiled_tmem_load)
|
||||
_thr_smem_store_p = _tiled_smem_store_p.get_slice(sfw_idx)
|
||||
_tSMEM_STOREsP = _thr_smem_store_p.partition_D(sP)
|
||||
_tSMEM_STORErS = _thr_smem_store_p.partition_S(tTMEM_LOADtS)
|
||||
_rP_smem = cute.make_rmem_tensor(_tSMEM_STORErS.shape, self.q_dtype)
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
|
||||
Reference in New Issue
Block a user