From 228ec3c638396ab5d42e438ef9c0da1ade4fa5e7 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 02:43:42 +0000 Subject: [PATCH] D1.5: Replace broken make_cotiled_copy SMEM-P with coordinate-indexed store --- dsv4/kernels/attention/fmha.py | 40 ++++++++++------------------------ 1 file changed, 11 insertions(+), 29 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index d8c667c6..59354a64 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -371,35 +371,17 @@ class FmhaKernel: # Each softmax thread (128 total) owns one row of the 128×128 P matrix. # Within a row, 128 values decompose as (k0, k1, k2) = (16, 4, 2) # matching sP's subtiled layout strides (1, 16, 8192). - # Thread stride = 64 (one row in sP's layout). - # - # atom_layout_tv(tid, k0, k1, k2) = 64*tid + k0 + 16*k1 + 8192*k2 - # sP_addr(m, k0, k1, k2) = 64*m + k0 + 16*k1 + 8192*k2 - # These are the SAME function — tid maps to m naturally. - _smem_p_atom = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - self.q_dtype, - num_bits_per_copy=16, - ) - _smem_p_atom_layout_tv = cute.make_layout( - (128, (16, 4, 2)), - stride=(64, (1, 16, 8192)), - ) - # Build sP data layout (coalesced, no swizzle, no stage dim) - # sP_outer has shape ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0) - # Coalesced: ((128,16,4,2),(64,1,16,8192)) - _smem_p_data_layout = cute.make_layout( - (128, 16, 4, 2), - stride=(64, 1, 16, 8192), - ) - _tiled_smem_p = cute.make_cotiled_copy( - _smem_p_atom, _smem_p_atom_layout_tv, _smem_p_data_layout, - ) - _thr_smem_p = _tiled_smem_p.get_slice(sfw_idx) - _tRS_sP = _thr_smem_p.partition_D(_sP_nostage) - _tRS_rP = _tiled_smem_p.retile(rP_bf16) - cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP) - cute.arch.fence("shared", space="cta") + # SMEM-P: write P to sP using coordinate-indexed store. + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] + m_coord = coord[0] + k_coord = coord[1] + k0 = k_coord % 16 + k1 = (k_coord // 16) % 4 + k2 = k_coord // 64 + _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] + cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: for i in range(n_corr_tiles): tTMrO_i_ = tTMrO[None, i]