diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index ff7b0d10..eb228c8c 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -272,44 +272,13 @@ class FmhaKernel: tTMEM_STOREcP = thr_store.partition_S(tScP) # P SMEM copy atoms: SMEM-P - # Uses make_cotiled_copy to create a custom R→S copy where: - # - Thread/value mapping: softmax/TMEM-load ownership (tTMEM_LOADcS) - # - Destination: sP in PV A-operand swizzled SMEM layout - # Per CUTLASS guidance: make_tiled_copy_C/D encode the wrong invariants - # for this transfer. We build a custom TV layout that maps (tid,vid) -> sP addr. - # Must define unconditionally (CuTeDSL scoping: compile both branches). - # Start with scalar BF16 stores (16-bit) — vectorize later once correct. - _r2s_atom = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - self.q_dtype, - num_bits_per_copy=16, # scalar BF16 — safe, vectorize later - ) - # Build atom_layout_tv: (tid, vid) -> sP address - # tTMEM_LOADcS gives (thr_offset, vid) -> (m, k) coordinate - # sP layout gives ((m,k0),0,(k1,k2),0) -> address (with swizzle) - # We compose these to get (tid, vid) -> sP address. - # Use sP_2d (grouped to 2D) for simplicity. - _sP_nostage = sP[(None, None, None, 0)] - _sP_2d = cute.group_modes(_sP_nostage, 0, 3) - _sP_2d_layout = _sP_2d.layout - # Flatten tTMEM_LOADcS to (total_elements,) -> (m, k) coords - _p_coord_layout = cute.flatten(tTMEM_LOADcS.layout) - # Compose: (tid, vid) -> (m, k) via _p_coord_layout, then (m, k) -> addr via sP_2d - # make_cotiled_copy needs atom_layout_tv where the codomain is in the sP address space. - # composition(sP_2d_layout, p_coord_layout) should give this. - _p_tv_layout = cute.composition(_sP_2d_layout, _p_coord_layout) - _tiled_p_r2s = cute.make_cotiled_copy( - _r2s_atom, - _p_tv_layout, - _sP_2d_layout, - ) - _thr_p_r2s = _tiled_p_r2s.get_slice(sfw_idx) - _tRS_sP = _thr_p_r2s.partition_D(_sP_2d) - # Source: register tensor in the copy's value order. - # The softmax computes P in rP_bf16 (TME load layout). We retile it - # into the copy's expected value order, or create a new source tensor - # and fill it during softmax. - _rP_store = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype) + # Per CUTLASS guidance: make_tiled_copy_C/D encode wrong invariants. + # Use direct coordinate-indexed write to sP. + # Each softmax thread knows its (m, k) from tTMEM_LOADcS. + # sP is indexed as sP[(m, k%16), 0, ((k//16)%4, k//64), stage]. + # CuTeDSL tensor indexing handles the swizzle automatically. + # Must define unconditionally (CuTeDSL scoping). + _sP_nostage = sP[(None, None, None, 0)] # remove stage dim row_max = -Float32.inf row_sum = Float32(0.0) @@ -382,13 +351,20 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: store P to SMEM via make_cotiled_copy - # Fill _rP_store with P values (in the copy's value order). - # For now, zero-fill to test compilation. The real P values - # will be filled by remapping from rP_bf16 to _rP_store's order. - for j in cutlass.range(cute.size(_rP_store), vectorize=True): - _rP_store[j] = self.q_dtype(0) - cute.copy(_tiled_p_r2s, _rP_store, _tRS_sP) + # SMEM-P: write P to sP using coordinate-indexed store. + # Each thread knows its (m, k) from tTMEM_LOADcS. + # Index sP at ((m, k%16), 0, ((k//16)%4, k//64), 0). + # CuTeDSL tensor indexing handles the swizzle automatically. + for j0 in range(cute.size(tTMEM_LOADcS, mode=[0])): + for j1 in range(cute.size(tTMEM_LOADcS, mode=[1])): + m_coord = tTMEM_LOADcS[j0, j1, 0, 0, 0] + k_coord = tTMEM_LOADcS[j0, j1, 0, 0, 1] + # Decompose k into sP's sub-coordinates + k0 = k_coord % 16 + k1 = (k_coord // 16) % 4 + k2 = k_coord // 64 + # Write P value to sP (swizzle handled by tensor layout) + _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[j0, j1, 0, 0] cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: tTMrO = cute.make_rmem_tensor(