Fix SMEM-P copy rank mismatch (use rP_bf16 directly instead of group_modes)

This commit is contained in:
2026-05-23 09:21:13 +00:00
parent 83a7dd0679
commit f6b43227e5

View File

@@ -344,9 +344,8 @@ class FmhaKernel:
else:
# SMEM-P: write P to SMEM via tiled_smem_copy
# rP_bf16 contains P values in QK C-fragment layout (BF16)
# Flatten to 2D for copy operation
rP_bf16_2d = cute.group_modes(rP_bf16, 0, 2)
tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16_2d)
# Use rP_bf16 directly (already in correct layout for QK C-fragment)
tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16)
cute.copy(tiled_smem_copy, tSMEM_CPYrP, tSMEM_CPYsP)
cute.arch.fence_proxy("async.shared", space="cta")
softmax_done_bar.arrive()