From 5be5d42e94e4c2b3c7dfab1dff2788f308850029 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 23:24:54 +0000 Subject: [PATCH] D1.3: Compute (m,k) directly from thread mapping instead of identity tensor --- dsv4/kernels/attention/fmha.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 2b3bfa5b..efa43e14 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -351,18 +351,32 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: write P to sP using coordinate-indexed store. + # SMEM-P: write P to sP using direct (m,k) computation. + # Instead of looking up coordinates from identity tensor, + # compute (m, k) from the TMEM load's known thread mapping. + # The Ld32x32bOp(Repetition(32)) distributes elements as: + # Each warp (32 threads) handles 32 rows of the S matrix. + # 4 softmax warps (warps 0-3) handle all 128 rows. + # Within a warp, each thread handles 1 row and 4 column groups. + # Thread sfw_idx: warp_id = sfw_idx // 32, lane = sfw_idx % 32 + # Row: lane + 32 * warp_id + # Columns: 4 fragments of 32 elements each. + # Fragment j1: columns [j1*32, (j1+1)*32) + # Within fragment: column = j0 + j1 * 32 + _warp_id = sfw_idx // 32 + _lane = sfw_idx % 32 + _row = _lane + 32 * _warp_id for j0 in range(32): for j1 in range(4): - coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] - m_coord = coord[0] - k_coord = coord[1] - # DEBUG: print first 8 coords from thread 0 - if sfw_idx == 0 and kt == 0 and j0 < 2 and j1 < 2: - print(f"[SMEM-P] j0={j0} j1={j1} m={m_coord} k={k_coord} P={rP_bf16[(j0, 0), j1, 0, 0]}") + k_coord = j0 + 32 * j1 # column index + m_coord = _row # row index + # Compute sP sub-coordinates k0 = k_coord % 16 k1 = (k_coord // 16) % 4 k2 = k_coord // 64 + # rP_bf16 uses the same indexing as tTMEM_LOADrS + # which has the same layout as tTMEM_LOADcS. + # rP_bf16[(j0, 0), j1, 0, 0] should give P at (m, k). _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: