D1.3: Fix coord extraction - identity tensor stores (m,k) pairs as values
This commit is contained in:
@@ -352,13 +352,16 @@ class FmhaKernel:
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: write P to sP using coordinate-indexed store.
|
||||
# tTMEM_LOADcS shape: ((32,1),4,1,1) with layout ((32,1),4,1,1)
|
||||
# First mode is (32,1) — 32 m-coordinates per fragment, 1 k-slice.
|
||||
# So indexing: tTMEM_LOADcS[(j0, 0), j1, 0, 0] gives (m, k).
|
||||
# tTMEM_LOADcS contains (m, k) coordinates from identity tensor.
|
||||
# Shape: ((32,1),4,1,1) — indexed with 4 indices.
|
||||
# Each element is an (m, k) coordinate pair.
|
||||
# Extract m with .load()[0] and k with .load()[1],
|
||||
# or use indexing tTMEM_LOADcS[...].value[0/1].
|
||||
for j0 in range(32):
|
||||
for j1 in range(4):
|
||||
m_coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0, 0]
|
||||
k_coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0, 1]
|
||||
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
|
||||
m_coord = coord[0]
|
||||
k_coord = coord[1]
|
||||
k0 = k_coord % 16
|
||||
k1 = (k_coord // 16) % 4
|
||||
k2 = k_coord // 64
|
||||
|
||||
Reference in New Issue
Block a user