Add debug prints for SMEM-P partition layouts
This commit is contained in:
@@ -271,8 +271,11 @@ class FmhaKernel:
|
||||
)
|
||||
tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
|
||||
thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx)
|
||||
sP_2d = cute.group_modes(sP, 0, 3)
|
||||
tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d) # destination (SMEM)
|
||||
# Debug sP shape
|
||||
print(f"[SMEM-P PROPER] sP shape: {cute.shape(sP)} rank: {len(cute.shape(sP))}")
|
||||
# Try without group_modes first
|
||||
tSMEM_CPYsP = thr_smem_copy.partition_D(sP) # destination (SMEM)
|
||||
print(f"[SMEM-P PROPER] After partition_D: tSMEM_CPYsP layout: {tSMEM_CPYsP.layout}")
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
@@ -352,7 +355,9 @@ class FmhaKernel:
|
||||
rP_bf16_qk = cute.make_tensor(cute.recast_ptr(rP_bf16.iterator, dtype=self.q_dtype), rP_qk_layout)
|
||||
|
||||
# Partition source with QK layout
|
||||
print(f"[SMEM-P PROPER] Before partition_S: rP_bf16_qk shape: {cute.shape(rP_bf16_qk)} layout: {rP_bf16_qk.layout}")
|
||||
tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_bf16_qk)
|
||||
print(f"[SMEM-P PROPER] After partition_S: tSMEM_CPYrP_qk layout: {tSMEM_CPYrP_qk.layout}")
|
||||
|
||||
# Debug shapes
|
||||
print(f"[SMEM-P PROPER] rP_bf16 shape: {cute.shape(rP_bf16)}, layout: TMEM")
|
||||
|
||||
Reference in New Issue
Block a user