D1.3: Compute (m,k) directly from thread mapping instead of identity tensor
This commit is contained in:
@@ -351,18 +351,32 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: write P to sP using coordinate-indexed store.
|
||||
# SMEM-P: write P to sP using direct (m,k) computation.
|
||||
# Instead of looking up coordinates from identity tensor,
|
||||
# compute (m, k) from the TMEM load's known thread mapping.
|
||||
# The Ld32x32bOp(Repetition(32)) distributes elements as:
|
||||
# Each warp (32 threads) handles 32 rows of the S matrix.
|
||||
# 4 softmax warps (warps 0-3) handle all 128 rows.
|
||||
# Within a warp, each thread handles 1 row and 4 column groups.
|
||||
# Thread sfw_idx: warp_id = sfw_idx // 32, lane = sfw_idx % 32
|
||||
# Row: lane + 32 * warp_id
|
||||
# Columns: 4 fragments of 32 elements each.
|
||||
# Fragment j1: columns [j1*32, (j1+1)*32)
|
||||
# Within fragment: column = j0 + j1 * 32
|
||||
_warp_id = sfw_idx // 32
|
||||
_lane = sfw_idx % 32
|
||||
_row = _lane + 32 * _warp_id
|
||||
for j0 in range(32):
|
||||
for j1 in range(4):
|
||||
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
|
||||
m_coord = coord[0]
|
||||
k_coord = coord[1]
|
||||
# DEBUG: print first 8 coords from thread 0
|
||||
if sfw_idx == 0 and kt == 0 and j0 < 2 and j1 < 2:
|
||||
print(f"[SMEM-P] j0={j0} j1={j1} m={m_coord} k={k_coord} P={rP_bf16[(j0, 0), j1, 0, 0]}")
|
||||
k_coord = j0 + 32 * j1 # column index
|
||||
m_coord = _row # row index
|
||||
# Compute sP sub-coordinates
|
||||
k0 = k_coord % 16
|
||||
k1 = (k_coord // 16) % 4
|
||||
k2 = k_coord // 64
|
||||
# rP_bf16 uses the same indexing as tTMEM_LOADrS
|
||||
# which has the same layout as tTMEM_LOADcS.
|
||||
# rP_bf16[(j0, 0), j1, 0, 0] should give P at (m, k).
|
||||
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
if kt > 0:
|
||||
|
||||
Reference in New Issue
Block a user