feat: SMEM-P using make_tiled_copy_tv + logical sP view
This commit is contained in:
@@ -366,37 +366,34 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: write P to sP using make_cotiled_copy.
|
||||
# Build atom_layout_tv: (tid, vid) -> sP_addr.
|
||||
# Each softmax thread (128 total) owns one row of the 128×128 P matrix.
|
||||
# Within a row, 128 values decompose as (k0, k1, k2) = (16, 4, 2)
|
||||
# matching sP's subtiled layout strides (1, 16, 8192).
|
||||
# SMEM-P: write P to sP using make_tiled_copy_tv.
|
||||
# Strategy: Create a R→S copy whose TV layout maps softmax
|
||||
# thread/value ownership to sP's address space.
|
||||
#
|
||||
# Key assumption: 128 softmax threads each own one row of the
|
||||
# 128×128 P matrix. Within each row, values are in sP's
|
||||
# subtiled format: (16, 4, 2) with strides (1, 16, 8192).
|
||||
# Thread stride = 64 (one row in sP's layout).
|
||||
#
|
||||
# atom_layout_tv(tid, k0, k1, k2) = 64*tid + k0 + 16*k1 + 8192*k2
|
||||
# sP_addr(m, k0, k1, k2) = 64*m + k0 + 16*k1 + 8192*k2
|
||||
# These are the SAME function — tid maps to m naturally.
|
||||
# Using make_tiled_copy_tv with thread layout and value layout:
|
||||
# thr_layout: (128,) with stride (1,) — 128 threads
|
||||
# val_layout: (16, 4, 2) with stride (1, 16, 8192) — 128 values
|
||||
_smem_p_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
self.q_dtype,
|
||||
num_bits_per_copy=16,
|
||||
)
|
||||
_smem_p_atom_layout_tv = cute.make_layout(
|
||||
(128, (16, 4, 2)),
|
||||
stride=(64, (1, 16, 8192)),
|
||||
)
|
||||
# Build sP data layout (coalesced, no swizzle, no stage dim)
|
||||
# sP_outer has shape ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0)
|
||||
# Coalesced: ((128,16,4,2),(64,1,16,8192))
|
||||
_smem_p_data_layout = cute.make_layout(
|
||||
(128, 16, 4, 2),
|
||||
stride=(64, 1, 16, 8192),
|
||||
)
|
||||
_tiled_smem_p = cute.make_cotiled_copy(
|
||||
_smem_p_atom, _smem_p_atom_layout_tv, _smem_p_data_layout,
|
||||
_smem_p_thr_layout = cute.make_layout((128,), stride=(1,))
|
||||
_smem_p_val_layout = cute.make_layout((16, 4, 2), stride=(1, 16, 8192))
|
||||
_tiled_smem_p = cute.make_tiled_copy_tv(
|
||||
_smem_p_atom, _smem_p_thr_layout, _smem_p_val_layout,
|
||||
)
|
||||
_thr_smem_p = _tiled_smem_p.get_slice(sfw_idx)
|
||||
_tRS_sP = _thr_smem_p.partition_D(_sP_nostage)
|
||||
# Create a logical (non-swizzled) view of sP for partitioning
|
||||
_sP_logical = cute.make_tensor(_sP_nostage.iterator, cute.make_layout(
|
||||
(128, 16, 4, 2), stride=(64, 1, 16, 8192)
|
||||
))
|
||||
_tRS_sP = _thr_smem_p.partition_D(_sP_logical)
|
||||
_tRS_rP = _tiled_smem_p.retile(rP_bf16)
|
||||
cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP)
|
||||
cute.arch.fence("shared", space="cta")
|
||||
|
||||
Reference in New Issue
Block a user