D1.3: Use full sP (4D) for make_tiled_copy_D partition

This commit is contained in:
2026-05-23 22:27:11 +00:00
parent fa2e513168
commit 8d226a6243

View File

@@ -278,8 +278,9 @@ class FmhaKernel:
_tiled_smem_store_p = cute.make_tiled_copy_D(_p_smem_store_atom, tiled_tmem_load)
_thr_smem_store_p = _tiled_smem_store_p.get_slice(sfw_idx)
# Partition sP for the SMEM store (destination)
_sP_2d = cute.group_modes(sP, 0, 3)
_tSMEM_STOREsP = _thr_smem_store_p.partition_D(_sP_2d)
# sP has shape ((128,16),1,(4,2),1) — 4D. The tiler expects rank >= 3.
# Use the full sP tensor (not sP_2d) for proper tiling.
_tSMEM_STOREsP = _thr_smem_store_p.partition_D(sP)
# Create a source register tensor for the SMEM store
# The source layout comes from the TMEM load's source partition (tTMEM_LOADtS)
# But we need a BF16 view of P, not the FP32 S view.