From e41cf07f509fe0118a6767e614d29ea0de14eeae Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 02:34:50 +0000 Subject: [PATCH 1/2] fix: tTMrO scoping + restore SMEM-P coordinate write --- dsv4/kernels/attention/fmha.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index ae49935b..82b3bec7 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -366,13 +366,22 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: TEMPORARILY zero-fill sP (debugging deadlock). - # The coordinate-indexed write causes a deadlock at hd=256. - # TODO: Fix the SMEM-P write path. + # SMEM-P: write P to sP using coordinate-indexed store. + # Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates. for j0 in range(32): for j1 in range(4): - _sP_nostage[(j0, j1), 0, (0, 0)] = BFloat16(0.0) + coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] + m_coord = coord[0] + k_coord = coord[1] + k0 = k_coord % 16 + k1 = (k_coord // 16) % 4 + k2 = k_coord // 64 + _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] cute.arch.fence_proxy("async.shared", space="cta") + # O rescale register tensor (defined unconditionally for CuTeDSL scoping) + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) if kt > 0: for i in range(n_corr_tiles): tTMrO_i_ = tTMrO[None, i] From af8303ba6449354532714b18ff719230f64a4f63 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 02:36:08 +0000 Subject: [PATCH 2/2] fix: reorder tTMrO definition after tTMEM_LOADcO --- dsv4/kernels/attention/fmha.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 82b3bec7..3f1a70b6 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -291,12 +291,6 @@ class FmhaKernel: row_max = -Float32.inf row_sum = Float32(0.0) - - # Define tTMrO UNCONDITIONALLY (CuTeDSL scoping rule). - # Used for O rescale (kt > 0) and O normalization (after loop). - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) scale_log2 = Float32(self.scale_softmax_log2) # O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale) @@ -323,6 +317,12 @@ class FmhaKernel: tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) n_corr_tiles = self.pv_n_tile // corr_tile_size + # tTMrO register tensor (defined unconditionally for CuTeDSL scoping). + # Used for O rescale (kt > 0) and O normalization (after loop). + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) + for kt in range(self.n_kv_tiles): si_handle = s_cons.wait_and_advance()