From ed71f25903d82309772e231a0d07a655d577a2f4 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 20:58:06 +0000 Subject: [PATCH] D1.3: Revert to d1.3-pre-sm100-helpers baseline for testing --- dsv4/kernels/attention/fmha.py | 100 +++++++++++++++++++++++++++++---- 1 file changed, 89 insertions(+), 11 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 961c46c6..49522eb0 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -257,10 +257,19 @@ class FmhaKernel: tScP = cute.make_tensor(tScS.iterator, tScP_layout) tTMEM_STOREcP = thr_store.partition_S(tScP) - # P SMEM copy atoms: SMEM-P - # NOTE: make_tiled_copy_C fails (incompatible QK C-fragment vs PV A-operand layouts). - # SMEM-P proper copy is TBD. For now, SMEM-P path zero-fills sP. - # The TMEM-P path (hd<=64) works correctly without SMEM-P. + # P SMEM copy atoms: SMEM-P (always defined, only used when use_smem_p=True) + # Uses make_tiled_copy_C to partition threads by QK MMA's C-fragment layout. + # Softmax warps have P values in QK C-fragment layout (same as rP_bf16). + # This copy writes those values to sP which has PV A-operand SMEM layout. + smem_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.q_dtype, + num_bits_per_copy=128, + ) + tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma) + thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx) + sP_2d = cute.group_modes(sP, 0, 3) + tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d) # destination (SMEM) row_max = -Float32.inf row_sum = Float32(0.0) @@ -333,10 +342,32 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: zero-fill sP (proper SMEM-P copy TBD) - # The TMEM-P path works for hd<=64. SMEM-P needs layout-aware copy. - for j in cutlass.range(cute.size(sP), vectorize=True): - sP[j] = self.q_dtype(0) + # SMEM-P: Use QK C-fragment layout for source (not TMEM layout) + # rP_bf16 uses tTMEM_LOADrS.layout (TMEM layout) causing rank mismatch + # Create view with QK C-fragment layout (tStS0.layout) + rP_qk_layout = tStS0.layout # QK C-fragment layout for this thread + rP_qk = cute.make_tensor(cute.recast_ptr(rP_bf16.iterator, dtype=self.q_dtype), rP_qk_layout) + + # Partition source with QK layout + tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_qk) + + # Debug shapes + print(f"[SMEM-P PROPER] rP_bf16 shape: {cute.shape(rP_bf16)}, layout: TMEM") + print(f"[SMEM-P PROPER] rP_qk shape: {cute.shape(rP_qk)}, layout: QK C-fragment") + print(f"[SMEM-P PROPER] tSMEM_CPYrP_qk shape: {cute.shape(tSMEM_CPYrP_qk)} rank: {len(cute.shape(tSMEM_CPYrP_qk))}") + print(f"[SMEM-P PROPER] tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}") + + # Attempt copy with correct layout + try: + cute.copy(tiled_smem_copy, tSMEM_CPYrP_qk, tSMEM_CPYsP) + print(f"[SMEM-P PROPER] Copy succeeded with QK C-fragment layout") + except Exception as e: + print(f"[SMEM-P PROPER] Copy failed: {e}") + # Fallback to stub for now + for j in cutlass.range(cute.size(sP), vectorize=True): + sP[j] = BFloat16(0.0) + print(f"[SMEM-P PROPER] Used fallback stub") + cute.arch.fence_proxy("async.shared", space="cta") softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout) if kt > 0: @@ -369,9 +400,56 @@ class FmhaKernel: # Wait for MMA's PV[N-1] to commit before reading O. final_o_bar.arrive_and_wait() - # === DIAGNOSTIC: Test epilogue_tma_store WITHOUT any round-trips === - # If get_tmem_load_op reads O correctly from TMEM, this should give cos 0.9999 - # (un-normalized, just raw PV sum). Then we can add normalization back. + # === NO-OP TMEM round-trip: re-map O from MMA layout to epilog layout === + tTMrO_noop = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) + for i in range(n_corr_tiles): + tTMrO_i_ = tTMrO_noop[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO_noop.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, + tTMEM_LOADtO.layout, + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, + tTMEM_STOREtO.layout, + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) + cute.arch.fence_view_async_tmem_store() + + # === Final O normalization: O *= 1/row_sum === + inv_row_sum = Float32(1.0) / row_sum + + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) + + for i in range(n_corr_tiles): + tTMrO_i_ = tTMrO[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout + ) + + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[j] = tTMrO_i[j] * inv_row_sum + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) + + cute.arch.fence_view_async_tmem_store() + + # Epilogue: TMEM → SMEM → GMEM via TMA store. tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) acc_cons_st = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.num_acc_stage