diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 897a0966..33bce4ce 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -268,9 +268,7 @@ class FmhaKernel: ) tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma) thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx) - print(f"[SMEM-P DEBUG] sP shape: {cute.shape(sP)}") - sP_2d = cute.group_modes(sP, 0, 3) - print(f"[SMEM-P DEBUG] sP_2d shape: {cute.shape(sP_2d)}") # flatten to 2D for copy + print(f"[SMEM-P DEBUG] sP_2d_debug shape: {cute.shape(sP_2d_debug)}") # flatten to 2D for copy tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d) # destination (SMEM) row_max = -Float32.inf @@ -356,10 +354,10 @@ class FmhaKernel: print(f"[SMEM-P] Approach1 - dst shape: {cute.shape(tSMEM_CPYsP1)} rank: {len(cute.shape(tSMEM_CPYsP1))}") # Approach 2: Try with sP_2d - sP_2d = cute.group_modes(sP, 0, 3) - print(f"[SMEM-P] sP_2d shape: {cute.shape(sP_2d)} rank: {len(cute.shape(sP_2d))}") + sP_2d_debug = cute.group_modes(sP, 0, 3) + print(f"[SMEM-P] sP_2d_debug shape: {cute.shape(sP_2d_debug)} rank: {len(cute.shape(sP_2d))}") tSMEM_CPYrP2 = thr_smem_copy.partition_S(rP_bf16) - tSMEM_CPYsP2 = thr_smem_copy.partition_D(sP_2d) + tSMEM_CPYsP2 = thr_smem_copy.partition_D(sP_2d_debug) print(f"[SMEM-P] Approach2 - src shape: {cute.shape(tSMEM_CPYrP2)} rank: {len(cute.shape(tSMEM_CPYrP2))}") print(f"[SMEM-P] Approach2 - dst shape: {cute.shape(tSMEM_CPYsP2)} rank: {len(cute.shape(tSMEM_CPYsP2))}")