D1.3: Direct coordinate-indexed SMEM-P write using tTMEM_LOADcS coords

Each softmax thread writes its P values to sP using the (m, k) coordinates
from tTMEM_LOADcS. The k coordinate is decomposed into
(k0, k1, k2) = (k % 16, (k // 16) % 4, k // 64) to match sP's
((128,16),1,(4,2)) layout. CuTeDSL tensor indexing handles the swizzle
automatically, so neither make_cotiled_copy nor make_tiled_copy is needed.
This commit is contained in:
2026-05-23 23:19:21 +00:00
parent 58b4537741
commit 4b8970d83c

View File

@@ -272,44 +272,13 @@ class FmhaKernel:
tTMEM_STOREcP = thr_store.partition_S(tScP)
# P SMEM copy atoms: SMEM-P
# Uses make_cotiled_copy to create a custom R→S copy where:
# - Thread/value mapping: softmax/TMEM-load ownership (tTMEM_LOADcS)
# - Destination: sP in PV A-operand swizzled SMEM layout
# Per CUTLASS guidance: make_tiled_copy_C/D encode the wrong invariants
# for this transfer. We build a custom TV layout that maps (tid,vid) -> sP addr.
# Must define unconditionally (CuTeDSL scoping: compile both branches).
# Start with scalar BF16 stores (16-bit) — vectorize later once correct.
_r2s_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.q_dtype,
num_bits_per_copy=16, # scalar BF16 — safe, vectorize later
)
# Build atom_layout_tv: (tid, vid) -> sP address
# tTMEM_LOADcS gives (thr_offset, vid) -> (m, k) coordinate
# sP layout gives ((m,k0),0,(k1,k2),0) -> address (with swizzle)
# We compose these to get (tid, vid) -> sP address.
# Use sP_2d (grouped to 2D) for simplicity.
_sP_nostage = sP[(None, None, None, 0)]
_sP_2d = cute.group_modes(_sP_nostage, 0, 3)
_sP_2d_layout = _sP_2d.layout
# Flatten tTMEM_LOADcS to (total_elements,) -> (m, k) coords
_p_coord_layout = cute.flatten(tTMEM_LOADcS.layout)
# Compose: (tid, vid) -> (m, k) via _p_coord_layout, then (m, k) -> addr via sP_2d
# make_cotiled_copy needs atom_layout_tv where the codomain is in the sP address space.
# composition(sP_2d_layout, p_coord_layout) should give this.
_p_tv_layout = cute.composition(_sP_2d_layout, _p_coord_layout)
_tiled_p_r2s = cute.make_cotiled_copy(
_r2s_atom,
_p_tv_layout,
_sP_2d_layout,
)
_thr_p_r2s = _tiled_p_r2s.get_slice(sfw_idx)
_tRS_sP = _thr_p_r2s.partition_D(_sP_2d)
# Source: register tensor in the copy's value order.
# The softmax computes P in rP_bf16 (TMEM load layout). We retile it
# into the copy's expected value order, or create a new source tensor
# and fill it during softmax.
_rP_store = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype)
# Per CUTLASS guidance: make_tiled_copy_C/D encode wrong invariants.
# Use direct coordinate-indexed write to sP.
# Each softmax thread knows its (m, k) from tTMEM_LOADcS.
# sP is indexed as sP[(m, k%16), 0, ((k//16)%4, k//64), stage].
# CuTeDSL tensor indexing handles the swizzle automatically.
# Must define unconditionally (CuTeDSL scoping).
_sP_nostage = sP[(None, None, None, 0)] # remove stage dim
row_max = -Float32.inf
row_sum = Float32(0.0)
@@ -382,13 +351,20 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: store P to SMEM via make_cotiled_copy
# Fill _rP_store with P values (in the copy's value order).
# For now, zero-fill to test compilation. The real P values
# will be filled by remapping from rP_bf16 to _rP_store's order.
for j in cutlass.range(cute.size(_rP_store), vectorize=True):
_rP_store[j] = self.q_dtype(0)
cute.copy(_tiled_p_r2s, _rP_store, _tRS_sP)
# SMEM-P: write P to sP using coordinate-indexed store.
# Each thread knows its (m, k) from tTMEM_LOADcS.
# Index _sP_nostage at ((m, k%16), 0, ((k//16)%4, k//64)); the stage
# mode is not indexed here because _sP_nostage was sliced at stage 0.
# CuTeDSL tensor indexing handles the swizzle automatically.
for j0 in range(cute.size(tTMEM_LOADcS, mode=[0])):
for j1 in range(cute.size(tTMEM_LOADcS, mode=[1])):
m_coord = tTMEM_LOADcS[j0, j1, 0, 0, 0]
k_coord = tTMEM_LOADcS[j0, j1, 0, 0, 1]
# Decompose k into sP's sub-coordinates
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64
# Write P value to sP (swizzle handled by tensor layout)
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[j0, j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
if kt > 0:
tTMrO = cute.make_rmem_tensor(