feat: SMEM-P with make_tiled_copy_tv + manual fill
This commit is contained in:
@@ -375,9 +375,8 @@ class FmhaKernel:
|
||||
# subtiled format: (16, 4, 2) with strides (1, 16, 8192).
|
||||
# Thread stride = 64 (one row in sP's layout).
|
||||
#
|
||||
# Using make_tiled_copy_tv with thread layout and value layout:
|
||||
# thr_layout: (128,) with stride (1,) — 128 threads
|
||||
# val_layout: (16, 4, 2) with stride (1, 16, 8192) — 128 values
|
||||
# We create the source register tensor in the copy's value order
|
||||
# and fill it from rP_bf16 during the softmax loop.
|
||||
_smem_p_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
self.q_dtype,
|
||||
@@ -394,7 +393,24 @@ class FmhaKernel:
|
||||
(128, 16, 4, 2), stride=(64, 1, 16, 8192)
|
||||
))
|
||||
_tRS_sP = _thr_smem_p.partition_D(_sP_logical)
|
||||
_tRS_rP = _tiled_smem_p.retile(rP_bf16)
|
||||
# Create source register tensor in copy's value order
|
||||
_tRS_rP = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype)
|
||||
# Fill _tRS_rP from rP_bf16 using coordinate-based mapping
|
||||
# Each thread owns row sfw_idx of the P matrix.
|
||||
# rP_bf16 is in TMEM-load register order: ((32,1),4,1,1)
|
||||
# _tRS_rP is in copy value order: (16, 4, 2) per thread
|
||||
# For thread i, rP_bf16[(j0,0),j1,0,0] corresponds to P[i, j0+32*j1]
|
||||
# And _tRS_rP[k0,k1,k2] corresponds to P[i, k0+16*k1+64*k2]
|
||||
# So k0+16*k1+64*k2 = j0+32*j1, which means:
|
||||
# j0 = (k0+16*k1+64*k2) % 32
|
||||
# j1 = (k0+16*k1+64*k2) // 32
|
||||
for k0 in range(16):
|
||||
for k1 in range(4):
|
||||
for k2 in range(2):
|
||||
flat_k = k0 + 16 * k1 + 64 * k2 # k index in [0,128)
|
||||
j0 = flat_k % 32
|
||||
j1 = flat_k // 32
|
||||
_tRS_rP[k0, k1, k2] = rP_bf16[(j0, 0), j1, 0, 0]
|
||||
cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP)
|
||||
cute.arch.fence("shared", space="cta")
|
||||
if kt > 0:
|
||||
|
||||
Reference in New Issue
Block a user