Fix SMEM-P copy rank mismatch (use rP_bf16 directly instead of group_modes)
This commit is contained in:
@@ -344,9 +344,8 @@ class FmhaKernel:
|
||||
else:
|
||||
# SMEM-P: write P to SMEM via tiled_smem_copy
|
||||
# rP_bf16 contains P values in QK C-fragment layout (BF16)
|
||||
# Flatten to 2D for copy operation
|
||||
rP_bf16_2d = cute.group_modes(rP_bf16, 0, 2)
|
||||
tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16_2d)
|
||||
# Use rP_bf16 directly (already in correct layout for QK C-fragment)
|
||||
tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16)
|
||||
cute.copy(tiled_smem_copy, tSMEM_CPYrP, tSMEM_CPYsP)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
Reference in New Issue
Block a user