feat: SMEM-P with make_tiled_copy_tv + partition_S
This commit is contained in:
@@ -366,39 +366,45 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: write P to sP using a TiledCopy derived from PV MMA's
|
||||
# A-operand layout, but with softmax thread mapping.
|
||||
# SMEM-P: write P to sP using make_tiled_copy_tv.
|
||||
#
|
||||
# Strategy: Use the PV MMA's A-operand SMEM layout (which matches sP)
|
||||
# and create a new TiledCopy with softmax thread/value layout.
|
||||
# The softmax threads (128 total) each own one row of P.
|
||||
# Within each row, values are in sP's subtiled format.
|
||||
# The TMEM-load copy partitions 128 softmax threads across the
|
||||
# 128×128 P matrix. We use the SAME thread layout (one thread
|
||||
# per row) but map values to sP's subtiled address space.
|
||||
#
|
||||
# We use pv_mma's make_tiled_copy_A to get the copy atom and tiling,
|
||||
# then override the thread layout for the softmax threads.
|
||||
_smem_p_tiled_copy = utils.sm100.make_tiled_copy_A(
|
||||
# Thread layout: (128,) stride (1,) — 128 threads, one per P row
|
||||
# Value layout: (16, 4, 2) stride (1, 16, 8192) — sP's k-subtiling
|
||||
# This gives atom_layout_tv: (tid, k0, k1, k2) → 64*tid + k0 + 16*k1 + 8192*k2
|
||||
# Which matches sP's logical layout: sP_addr(m, k0, k1, k2) = 64*m + k0 + 16*k1 + 8192*k2
|
||||
_smem_p_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
pv_mma, self.q_dtype,
|
||||
128, # tiler_mn - matches the P matrix tile size
|
||||
self.q_dtype,
|
||||
num_bits_per_copy=16,
|
||||
)
|
||||
# Get the softmax thread's partition
|
||||
_thr_smem_p = _smem_p_tiled_copy.get_slice(sfw_idx)
|
||||
# Create a logical (non-swizzled) view of sP for partitioning
|
||||
_sP_logical = cute.make_tensor(_sP_nostage.iterator, _sP_nostage.layout)
|
||||
_tRS_sP = _thr_smem_p.partition_D(_sP_logical)
|
||||
# Create source register tensor matching the copy's value order
|
||||
_tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype)
|
||||
# Fill _tRS_rP from rP_bf16.
|
||||
# rP_bf16 is in TMEM-load order: ((32,1),4,1,1) with 128 values
|
||||
# _tRS_rP is in copy value order. We need to map between them.
|
||||
# For the copy, each thread should own P[thread_row, :].
|
||||
# rP_bf16[(j0,0),j1,0,0] = P[thread_row, j0+32*j1]
|
||||
# We need to figure out the copy's value order for our thread.
|
||||
# PRINT THE SHAPES to understand the mapping
|
||||
# For now, fill with zeros as a baseline test
|
||||
for v_idx in cutlass.range(cute.size(_tRS_rP), vectorize=True):
|
||||
_tRS_rP[v_idx] = BFloat16(0.0)
|
||||
cute.copy(_smem_p_tiled_copy, _tRS_rP, _tRS_sP)
|
||||
_smem_p_thr_layout = cute.make_layout((128,), stride=(1,))
|
||||
_smem_p_val_layout = cute.make_layout((16, 4, 2), stride=(1, 16, 8192))
|
||||
_tiled_smem_p = cute.make_tiled_copy_tv(
|
||||
_smem_p_atom, _smem_p_thr_layout, _smem_p_val_layout,
|
||||
)
|
||||
_thr_smem_p = _tiled_smem_p.get_slice(sfw_idx)
|
||||
# Create source register tensor in copy's value order
|
||||
# The partition_D gives us the sP partition for this thread
|
||||
# We need a matching source tensor
|
||||
_tAS_sP = _thr_smem_p.partition_S(_sP_nostage)
|
||||
_tAS_rP = cute.make_rmem_tensor(_tAS_sP.shape, self.q_dtype)
|
||||
# Fill _tAS_rP from rP_bf16.
|
||||
# rP_bf16 is in TMEM-load register order: ((32,1),4,1,1)
|
||||
# Each thread owns row sfw_idx of the P matrix.
|
||||
# rP_bf16[(j0,0),j1,0,0] = P[sfw_idx, j0+32*j1]
|
||||
# The copy's value layout maps (k0,k1,k2) to P[sfw_idx, k0+16*k1+64*k2]
|
||||
# So P[sfw_idx, k] where k = k0+16*k1+64*k2 = j0+32*j1
|
||||
# We need to fill _tAS_rP[k0,k1,k2] = rP_bf16[(k%32,0), k//32, 0, 0]
|
||||
# But _tAS_rP's shape is determined by partition_S, which might be
|
||||
# different from (16,4,2). We need to figure out the actual shape.
|
||||
# For now, zero-fill as baseline, then we'll fill properly.
|
||||
for v_idx in cutlass.range(cute.size(_tAS_rP), vectorize=True):
|
||||
_tAS_rP[v_idx] = BFloat16(0.0)
|
||||
cute.copy(_tiled_smem_p, _tAS_rP, _tAS_sP)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
if kt > 0:
|
||||
for i in range(n_corr_tiles):
|
||||
|
||||
Reference in New Issue
Block a user