From f6b43227e5bf6af8b154e252fdacace6184175bb Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 09:21:13 +0000 Subject: [PATCH] Fix SMEM-P copy rank mismatch (use rP_bf16 directly instead of group_modes) --- dsv4/kernels/attention/fmha.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 95b56f7c..79af5205 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -344,9 +344,8 @@ class FmhaKernel: else: # SMEM-P: write P to SMEM via tiled_smem_copy # rP_bf16 contains P values in QK C-fragment layout (BF16) - # Flatten to 2D for copy operation - rP_bf16_2d = cute.group_modes(rP_bf16, 0, 2) - tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16_2d) + # Use rP_bf16 directly (already in correct layout for QK C-fragment) + tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16) cute.copy(tiled_smem_copy, tSMEM_CPYrP, tSMEM_CPYsP) cute.arch.fence_proxy("async.shared", space="cta") softmax_done_bar.arrive()