From a17dca508ddffaf4c8115aa3a8dccb26fd9cf75c Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 23:26:07 +0000 Subject: [PATCH] =?UTF-8?q?D1.3:=20Revert=20to=20zero-fill=20for=20sP=20-?= =?UTF-8?q?=20need=20to=20verify=20sP=E2=86=92PV=20pipeline=20first?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dsv4/kernels/attention/fmha.py | 30 +++--------------------------- 1 file changed, 3 insertions(+), 27 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index efa43e14..c21b6ab7 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -351,33 +351,9 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: write P to sP using direct (m,k) computation. - # Instead of looking up coordinates from identity tensor, - # compute (m, k) from the TMEM load's known thread mapping. - # The Ld32x32bOp(Repetition(32)) distributes elements as: - # Each warp (32 threads) handles 32 rows of the S matrix. - # 4 softmax warps (warps 0-3) handle all 128 rows. - # Within a warp, each thread handles 1 row and 4 column groups. - # Thread sfw_idx: warp_id = sfw_idx // 32, lane = sfw_idx % 32 - # Row: lane + 32 * warp_id - # Columns: 4 fragments of 32 elements each. - # Fragment j1: columns [j1*32, (j1+1)*32) - # Within fragment: column = j0 + j1 * 32 - _warp_id = sfw_idx // 32 - _lane = sfw_idx % 32 - _row = _lane + 32 * _warp_id - for j0 in range(32): - for j1 in range(4): - k_coord = j0 + 32 * j1 # column index - m_coord = _row # row index - # Compute sP sub-coordinates - k0 = k_coord % 16 - k1 = (k_coord // 16) % 4 - k2 = k_coord // 64 - # rP_bf16 uses the same indexing as tTMEM_LOADrS - # which has the same layout as tTMEM_LOADcS. - # rP_bf16[(j0, 0), j1, 0, 0] should give P at (m, k). - _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] + # SMEM-P: zero-fill sP for now (testing sP→PV pipeline) + for j in cutlass.range(cute.size(sP), vectorize=True): + sP[j] = self.q_dtype(0) cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: tTMrO = cute.make_rmem_tensor(