diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index eb228c8c..15089d02 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -352,19 +352,17 @@ class FmhaKernel: cute.arch.fence_view_async_tmem_store() else: # SMEM-P: write P to sP using coordinate-indexed store. - # Each thread knows its (m, k) from tTMEM_LOADcS. - # Index sP at ((m, k%16), 0, ((k//16)%4, k//64), 0). - # CuTeDSL tensor indexing handles the swizzle automatically. - for j0 in range(cute.size(tTMEM_LOADcS, mode=[0])): - for j1 in range(cute.size(tTMEM_LOADcS, mode=[1])): - m_coord = tTMEM_LOADcS[j0, j1, 0, 0, 0] - k_coord = tTMEM_LOADcS[j0, j1, 0, 0, 1] - # Decompose k into sP's sub-coordinates + # tTMEM_LOADcS shape: ((32,1),4,1,1) with layout ((32,1),4,1,1) + # First mode is (32,1) — 32 m-coordinates per fragment, 1 k-slice. + # So indexing: tTMEM_LOADcS[(j0, 0), j1, 0, 0] gives (m, k). + for j0 in range(32): + for j1 in range(4): + m_coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0, 0] + k_coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0, 1] k0 = k_coord % 16 k1 = (k_coord // 16) % 4 k2 = k_coord // 64 - # Write P value to sP (swizzle handled by tensor layout) - _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[j0, j1, 0, 0] + _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: tTMrO = cute.make_rmem_tensor(