D1.3: Direct coordinate-indexed SMEM-P write using tTMEM_LOADcS coords
Each softmax thread writes its P values to sP using the (m, k) coordinates from tTMEM_LOADcS. The k coordinate is decomposed into (k0, k1, k2) = (k % 16, (k // 16) % 4, k // 64) to match sP's ((128,16),1,(4,2)) layout. CuTeDSL tensor indexing handles the swizzle automatically. No make_tiled_copy / make_cotiled_copy needed.
This commit is contained in:
@@ -272,44 +272,13 @@ class FmhaKernel:
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# P SMEM copy atoms: SMEM-P
|
||||
# Uses make_cotiled_copy to create a custom R→S copy where:
|
||||
# - Thread/value mapping: softmax/TMEM-load ownership (tTMEM_LOADcS)
|
||||
# - Destination: sP in PV A-operand swizzled SMEM layout
|
||||
# Per CUTLASS guidance: make_tiled_copy_C/D encode the wrong invariants
|
||||
# for this transfer. We build a custom TV layout that maps (tid,vid) -> sP addr.
|
||||
# Must define unconditionally (CuTeDSL scoping: compile both branches).
|
||||
# Start with scalar BF16 stores (16-bit) — vectorize later once correct.
|
||||
_r2s_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
self.q_dtype,
|
||||
num_bits_per_copy=16, # scalar BF16 — safe, vectorize later
|
||||
)
|
||||
# Build atom_layout_tv: (tid, vid) -> sP address
|
||||
# tTMEM_LOADcS gives (thr_offset, vid) -> (m, k) coordinate
|
||||
# sP layout gives ((m,k0),0,(k1,k2),0) -> address (with swizzle)
|
||||
# We compose these to get (tid, vid) -> sP address.
|
||||
# Use sP_2d (grouped to 2D) for simplicity.
|
||||
_sP_nostage = sP[(None, None, None, 0)]
|
||||
_sP_2d = cute.group_modes(_sP_nostage, 0, 3)
|
||||
_sP_2d_layout = _sP_2d.layout
|
||||
# Flatten tTMEM_LOADcS to (total_elements,) -> (m, k) coords
|
||||
_p_coord_layout = cute.flatten(tTMEM_LOADcS.layout)
|
||||
# Compose: (tid, vid) -> (m, k) via _p_coord_layout, then (m, k) -> addr via sP_2d
|
||||
# make_cotiled_copy needs atom_layout_tv where the codomain is in the sP address space.
|
||||
# composition(sP_2d_layout, p_coord_layout) should give this.
|
||||
_p_tv_layout = cute.composition(_sP_2d_layout, _p_coord_layout)
|
||||
_tiled_p_r2s = cute.make_cotiled_copy(
|
||||
_r2s_atom,
|
||||
_p_tv_layout,
|
||||
_sP_2d_layout,
|
||||
)
|
||||
_thr_p_r2s = _tiled_p_r2s.get_slice(sfw_idx)
|
||||
_tRS_sP = _thr_p_r2s.partition_D(_sP_2d)
|
||||
# Source: register tensor in the copy's value order.
|
||||
# The softmax computes P in rP_bf16 (TMEM load layout). We retile it
|
||||
# into the copy's expected value order, or create a new source tensor
|
||||
# and fill it during softmax.
|
||||
_rP_store = cute.make_rmem_tensor(_tRS_sP.shape, self.q_dtype)
|
||||
# Per CUTLASS guidance: make_tiled_copy_C/D encode wrong invariants.
|
||||
# Use direct coordinate-indexed write to sP.
|
||||
# Each softmax thread knows its (m, k) from tTMEM_LOADcS.
|
||||
# sP is indexed as sP[(m, k%16), 0, ((k//16)%4, k//64), stage].
|
||||
# CuTeDSL tensor indexing handles the swizzle automatically.
|
||||
# Must define unconditionally (CuTeDSL scoping).
|
||||
_sP_nostage = sP[(None, None, None, 0)] # remove stage dim
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
@@ -382,13 +351,20 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: store P to SMEM via make_cotiled_copy
|
||||
# Fill _rP_store with P values (in the copy's value order).
|
||||
# For now, zero-fill to test compilation. The real P values
|
||||
# will be filled by remapping from rP_bf16 to _rP_store's order.
|
||||
for j in cutlass.range(cute.size(_rP_store), vectorize=True):
|
||||
_rP_store[j] = self.q_dtype(0)
|
||||
cute.copy(_tiled_p_r2s, _rP_store, _tRS_sP)
|
||||
# SMEM-P: write P to sP using coordinate-indexed store.
|
||||
# Each thread knows its (m, k) from tTMEM_LOADcS.
|
||||
# Index _sP_nostage (stage 0 already sliced out) at ((m, k%16), 0, ((k//16)%4, k//64)).
|
||||
# CuTeDSL tensor indexing handles the swizzle automatically.
|
||||
for j0 in range(cute.size(tTMEM_LOADcS, mode=[0])):
|
||||
for j1 in range(cute.size(tTMEM_LOADcS, mode=[1])):
|
||||
m_coord = tTMEM_LOADcS[j0, j1, 0, 0, 0]
|
||||
k_coord = tTMEM_LOADcS[j0, j1, 0, 0, 1]
|
||||
# Decompose k into sP's sub-coordinates
|
||||
k0 = k_coord % 16
|
||||
k1 = (k_coord // 16) % 4
|
||||
k2 = k_coord // 64
|
||||
# Write P value to sP (swizzle handled by tensor layout)
|
||||
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[j0, j1, 0, 0]
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
if kt > 0:
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
|
||||
Reference in New Issue
Block a user