diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 56092241..b0cb49a3 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -366,52 +366,39 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: write P to sP using make_tiled_copy_tv. - # Strategy: Create a R→S copy whose TV layout maps softmax - # thread/value ownership to sP's address space. + # SMEM-P: write P to sP using a TiledCopy derived from PV MMA's + # A-operand layout, but with softmax thread mapping. # - # Key assumption: 128 softmax threads each own one row of the - # 128×128 P matrix. Within each row, values are in sP's - # subtiled format: (16, 4, 2) with strides (1, 16, 8192). - # Thread stride = 64 (one row in sP's layout). + # Strategy: Use the PV MMA's A-operand SMEM layout (which matches sP) + # and create a new TiledCopy with softmax thread/value layout. + # The softmax threads (128 total) each own one row of P. + # Within each row, values are in sP's subtiled format. # - # We create the source register tensor in the copy's value order - # and fill it from rP_bf16 during the softmax loop. - _smem_p_atom = cute.make_copy_atom( + # We use pv_mma's make_tiled_copy_A to get the copy atom and tiling, + # then override the thread layout for the softmax threads. + _smem_p_tiled_copy = utils.sm100.make_tiled_copy_A( cute.nvgpu.CopyUniversalOp(), - self.q_dtype, - num_bits_per_copy=16, + pv_mma, self.q_dtype, + 128, # tiler_mn - matches the P matrix tile size ) - _smem_p_thr_layout = cute.make_layout((128,), stride=(1,)) - _smem_p_val_layout = cute.make_layout((16, 4, 2), stride=(1, 16, 8192)) - _tiled_smem_p = cute.make_tiled_copy_tv( - _smem_p_atom, _smem_p_thr_layout, _smem_p_val_layout, - ) - _thr_smem_p = _tiled_smem_p.get_slice(sfw_idx) + # Get the softmax thread's partition + _thr_smem_p = _smem_p_tiled_copy.get_slice(sfw_idx) # Create a logical (non-swizzled) view of sP for partitioning - _sP_logical = cute.make_tensor(_sP_nostage.iterator, cute.make_layout( - (128, 16, 4, 2), stride=(64, 1, 16, 8192) - )) + _sP_logical = cute.make_tensor(_sP_nostage.iterator, _sP_nostage.layout) _tRS_sP = _thr_smem_p.partition_D(_sP_logical) - # Create source register tensor in copy's value order + # Create source register tensor matching the copy's value order _tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype) - # Fill _tRS_rP from rP_bf16 using coordinate-based mapping - # Each thread owns row sfw_idx of the P matrix. - # rP_bf16 is in TMEM-load register order: ((32,1),4,1,1) - # _tRS_rP is in copy value order: (16, 4, 2) per thread - # For thread i, rP_bf16[(j0,0),j1,0,0] corresponds to P[i, j0+32*j1] - # And _tRS_rP[k0,k1,k2] corresponds to P[i, k0+16*k1+64*k2] - # So k0+16*k1+64*k2 = j0+32*j1, which means: - # j0 = (k0+16*k1+64*k2) % 32 - # j1 = (k0+16*k1+64*k2) // 32 - for k0 in range(16): - for k1 in range(4): - for k2 in range(2): - flat_k = k0 + 16 * k1 + 64 * k2 # k index in [0,128) - j0 = flat_k % 32 - j1 = flat_k // 32 - _tRS_rP[k0, k1, k2] = rP_bf16[(j0, 0), j1, 0, 0] - cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP) + # Fill _tRS_rP from rP_bf16. + # rP_bf16 is in TMEM-load order: ((32,1),4,1,1) with 128 values + # _tRS_rP is in copy value order. We need to map between them. + # For the copy, each thread should own P[thread_row, :]. + # rP_bf16[(j0,0),j1,0,0] = P[thread_row, j0+32*j1] + # We need to figure out the copy's value order for our thread. + # PRINT THE SHAPES to understand the mapping + # For now, fill with zeros as a baseline test + for v_idx in cutlass.range(cute.size(_tRS_rP), vectorize=True): + _tRS_rP[v_idx] = BFloat16(0.0) + cute.copy(_smem_p_tiled_copy, _tRS_rP, _tRS_sP) cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: for i in range(n_corr_tiles):