diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 95b56f7c..79af5205 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -344,9 +344,8 @@ class FmhaKernel: else: # SMEM-P: write P to SMEM via tiled_smem_copy # rP_bf16 contains P values in QK C-fragment layout (BF16) - # Flatten to 2D for copy operation - rP_bf16_2d = cute.group_modes(rP_bf16, 0, 2) - tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16_2d) + # Use rP_bf16 directly (already in correct layout for QK C-fragment) + tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16) cute.copy(tiled_smem_copy, tSMEM_CPYrP, tSMEM_CPYsP) cute.arch.fence_proxy("async.shared", space="cta") softmax_done_bar.arrive()