From bafcfa658faef2d49381219be198553eb28f04d9 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 22:28:12 +0000 Subject: [PATCH] D1.3: Define SMEM-P copy atoms unconditionally (CuTeDSL scoping) --- dsv4/kernels/attention/fmha.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 52a5d6ce..99c47ff7 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -271,26 +271,16 @@ class FmhaKernel: # This creates a SMEM store copy using the same thread partition as the TMEM load, # so the same threads that compute P (softmax warps) can write P to sP directly. # The MMA warp then reads P from sP via pv_mma.make_fragment_A(sP). - if self.use_smem_p: - _p_smem_store_atom = get_smem_store_op( - self.c_layout, self.c_dtype, self.acc_dtype, tiled_tmem_load - ) - _tiled_smem_store_p = cute.make_tiled_copy_D(_p_smem_store_atom, tiled_tmem_load) - _thr_smem_store_p = _tiled_smem_store_p.get_slice(sfw_idx) - # Partition sP for the SMEM store (destination) - # sP has shape ((128,16),1,(4,2),1) — 4D. The tiler expects rank >= 3. - # Use the full sP tensor (not sP_2d) for proper tiling. - _tSMEM_STOREsP = _thr_smem_store_p.partition_D(sP) - # Create a source register tensor for the SMEM store - # The source layout comes from the TMEM load's source partition (tTMEM_LOADtS) - # But we need a BF16 view of P, not the FP32 S view. - # The softmax writes P values (BF16) via the register bridge rP_bf16, - # which shares layout with rS (tTMEM_LOADrS layout). - # make_tiled_copy_D uses the same thread partition, so the source partition - # should match the TMEM load source partition. - _tSMEM_STORErS = _thr_smem_store_p.partition_S(tTMEM_LOADtS) - # But rS is FP32. We need BF16. Create BF16 view of rP with same layout. - _rP_smem = cute.make_rmem_tensor(_tSMEM_STORErS.shape, self.q_dtype) + # Must define unconditionally (CuTeDSL scoping: compile both branches). + # For TMEM-P (use_smem_p=False), these are allocated but unused (dead-code-eliminated). + _p_smem_store_atom = get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_tmem_load + ) + _tiled_smem_store_p = cute.make_tiled_copy_D(_p_smem_store_atom, tiled_tmem_load) + _thr_smem_store_p = _tiled_smem_store_p.get_slice(sfw_idx) + _tSMEM_STOREsP = _thr_smem_store_p.partition_D(sP) + _tSMEM_STORErS = _thr_smem_store_p.partition_S(tTMEM_LOADtS) + _rP_smem = cute.make_rmem_tensor(_tSMEM_STORErS.shape, self.q_dtype) row_max = -Float32.inf row_sum = Float32(0.0)