D1.3: Use full sP (4D) for make_tiled_copy_D partition
This commit is contained in:
@@ -278,8 +278,9 @@ class FmhaKernel:
|
||||
_tiled_smem_store_p = cute.make_tiled_copy_D(_p_smem_store_atom, tiled_tmem_load)
|
||||
_thr_smem_store_p = _tiled_smem_store_p.get_slice(sfw_idx)
|
||||
# Partition sP for the SMEM store (destination)
|
||||
_sP_2d = cute.group_modes(sP, 0, 3)
|
||||
_tSMEM_STOREsP = _thr_smem_store_p.partition_D(_sP_2d)
|
||||
# sP has shape ((128,16),1,(4,2),1) — 4D. The tiler expects rank >= 3.
|
||||
# Use the full sP tensor (not sP_2d) for proper tiling.
|
||||
_tSMEM_STOREsP = _thr_smem_store_p.partition_D(sP)
|
||||
# Create a source register tensor for the SMEM store
|
||||
# The source layout comes from the TMEM load's source partition (tTMEM_LOADtS)
|
||||
# But we need a BF16 view of P, not the FP32 S view.
|
||||
|
||||
Reference in New Issue
Block a user