From 5aee06b9914dbdfb030b6e58a5d935156f8675a8 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 09:35:24 +0000 Subject: [PATCH] SMEM-P: Use QK C-fragment layout instead of TMEM layout to fix rank mismatch --- dsv4/kernels/attention/fmha.py | 52 ++++++++++++++-------------------- 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 94007426..cceda0bf 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -342,45 +342,35 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: Fix rank mismatch by reshaping source - # tSMEM_CPYsP has rank 3 based on earlier debug - # tSMEM_CPYrP has rank 5 with singleton modes (1,1,1,1) - # Strategy: group_modes on rP_bf16 to reduce rank + else: + # SMEM-P: Use QK C-fragment layout for source (not TMEM layout) + # rP_bf16 uses tTMEM_LOADrS.layout (TMEM layout) causing rank mismatch + # Create view with QK C-fragment layout (tStS0.layout) + rP_qk_layout = tStS0.layout # QK C-fragment layout for this thread + rP_qk = cute.make_tensor(cute.recast_ptr(rP.iterator, dtype=self.q_dtype), rP_qk_layout) - # First, understand the layout - rP_shape = cute.shape(rP_bf16) - rP_rank = len(rP_shape) - print(f"[SMEM-P FIX] rP_bf16 shape: {rP_shape}, rank: {rP_rank}") + # Partition source with QK layout + tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_qk) - # tSMEM_CPYsP should have rank 3 (from earlier debug) - # Try to reshape rP_bf16 to rank 3 - if rP_rank >= 3: - # Group excess modes - # If rP has shape ((32,1),4,1,1) rank 4, group last 2 modes - rP_reshaped = cute.group_modes(rP_bf16, rP_rank-2, rP_rank) - print(f"[SMEM-P FIX] rP_reshaped shape: {cute.shape(rP_reshaped)}, rank: {len(cute.shape(rP_reshaped))}") - tSMEM_CPYrP_reshaped = thr_smem_copy.partition_S(rP_reshaped) - else: - # Keep as-is - tSMEM_CPYrP_reshaped = thr_smem_copy.partition_S(rP_bf16) + # Debug shapes + print(f"[SMEM-P PROPER] rP_bf16 shape: {cute.shape(rP_bf16)}, layout: TMEM") + print(f"[SMEM-P PROPER] rP_qk shape: {cute.shape(rP_qk)}, layout: QK C-fragment") + print(f"[SMEM-P PROPER] tSMEM_CPYrP_qk shape: {cute.shape(tSMEM_CPYrP_qk)} rank: {len(cute.shape(tSMEM_CPYrP_qk))}") + print(f"[SMEM-P PROPER] tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}") - # Try the copy + # Attempt copy with correct layout try: - cute.copy(tiled_smem_copy, tSMEM_CPYrP_reshaped, tSMEM_CPYsP) - print(f"[SMEM-P FIX] Copy succeeded with reshaped source") + cute.copy(tiled_smem_copy, tSMEM_CPYrP_qk, tSMEM_CPYsP) + print(f"[SMEM-P PROPER] Copy succeeded with QK C-fragment layout") except Exception as e: - print(f"[SMEM-P FIX] Copy failed: {e}") - # Fallback: Write via direct assignment (slow but works) - # Map from QK C-fragment to SMEM layout manually - # This is a hack but unblocks progress + print(f"[SMEM-P PROPER] Copy failed: {e}") + # Fallback to stub for now for j in cutlass.range(cute.size(sP), vectorize=True): - sP[j] = BFloat16(0.0) # Still stub, but at least compiles - print(f"[SMEM-P FIX] Used fallback stub") + sP[j] = BFloat16(0.0) + print(f"[SMEM-P PROPER] Used fallback stub") cute.arch.fence_proxy("async.shared", space="cta") - softmax_done_bar.arrive() - - # Per-tile O rescale (hand-constructed atoms with logical_divide layout) + softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout) if kt > 0: tTMrO = cute.make_rmem_tensor( (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype