shit left dangling

This commit is contained in:
2026-05-23 23:58:57 +00:00
parent d092a1743a
commit 24bc318480

View File

@@ -0,0 +1,516 @@
"""
Diagnostic test for SMEM-P: verify the TV layout mapping from TMEM-load to sP.
This test does NOT compute attention. It:
1. Creates the TMEM-load copy and extracts its TV layout
2. Creates sP with PV A-operand layout
3. Builds atom_layout_tv mapping (tid,vid) → sP address
4. Validates the mapping is correct (all 16384 elements covered, no overlaps)
5. Tests the make_cotiled_copy approach
"""
import torch, math
import cutlass, cutlass.cute as cute
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as bh
from cutlass.cute.nvgpu import tcgen05, cpasync
from cutlass import Float32, BFloat16, Int32
import cutlass.torch as ct
import cuda.bindings.driver as cuda
def diag_smem_p_tv():
"""Print TV layout info from TMEM-load copy and sP to build the R→S copy."""
head_dim = 256 # First SMEM-P head dim (not 64 which uses TMEM-P)
s_k = 128
m = 128
pv_n_tile = min(head_dim, 256)
scale_softmax = 1.0 / math.sqrt(head_dim)
scale_softmax_log2 = scale_softmax * math.log2(math.e)
# Create tensors (shapes don't matter much, we just need the layouts)
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# Create CuTe tensors
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
# V layout: (pv_n_tile, s_k, 1) for FMHA
v_fmha = v[:, 0:pv_n_tile].contiguous()
v_kernel = v_fmha.unsqueeze(-1)
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c[:, 0:pv_n_tile, :]).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c[:, 0:pv_n_tile, :]))
# MMA setup
a_major = cute.LayoutEnum.from_tensor(mQ).mma_major_mode()
b_major = cute.LayoutEnum.from_tensor(mK).mma_major_mode()
v_major = cute.LayoutEnum.from_tensor(mV).mma_major_mode()
qk_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, a_major, b_major, Float32,
tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM
)
pv_a_major = a_major # SMEM-P: A from SMEM
pv_source = tcgen05.OperandSource.SMEM
pv_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, pv_a_major, v_major, Float32,
tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source
)
# Compute layouts (mirrors FmhaKernel._setup)
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
# P SMEM layout (PV A-operand, 1 stage)
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
# P TMEM layout (for reference)
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
# QK C-fragment
qk_thr = qk_mma.get_slice(0)
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) # offset 0 for diagnostics
# PV C-fragment
pv_thr = pv_mma.get_slice(0)
pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
# ====== TMEM-load copy (softmax reads S from TMEM) ======
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
print("=== TMEM-Load Copy Info ===")
print(f" layout_tv_tiled: {tiled_tmem_load.layout_tv_tiled}")
print(f" layout_dst_tv_tiled: {tiled_tmem_load.layout_dst_tv_tiled}")
print(f" layout_src_tv_tiled: {tiled_tmem_load.layout_src_tv_tiled}")
print(f" TV shape: {tiled_tmem_load.layout_tv_tiled.shape}")
print(f" TV stride: {tiled_tmem_load.layout_tv_tiled.stride}")
# Softmax thread partition
sfw_idx = 0 # thread 0 in softmax warps
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
# Coordinate identity tensor
cS = cute.make_identity_tensor((128, 128))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
print(f"\n=== Per-Thread Partition (thread 0) ===")
print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}")
print(f" tTMEM_LOADtS layout: {tTMEM_LOADtS.layout}")
print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
print(f" tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}")
# ====== sP layout (PV A-operand SMEM) ======
# Remove stage dim
sP_stage = p_smem_s.outer
if len(cute.shape(sP_stage)) > 3:
# Try removing the stage dimension
sP_stage = sP_stage # might need slicing
print(f"\n=== sP Layout (PV A-operand SMEM) ===")
print(f" p_smem_s.outer shape: {cute.shape(p_smem_s.outer)}")
print(f" p_smem_s.outer layout: {p_smem_s.outer}")
print(f" p_smem_s.inner (swizzle): {p_smem_s.inner}")
# Flatten sP to 1D for address computation
sP_flat = cute.coalesce(p_smem_s.outer)
print(f" sP_flat shape: {cute.shape(sP_flat)}")
print(f" sP_flat layout: {sP_flat}")
# ====== Build atom_layout_tv ======
# The TV layout of tiled_tmem_load maps (tid, vid) → tStS0 coordinate
# We need to compose with (m, k) → sP address mapping
#
# Approach:
# 1. Get the TV layout: (tid, vid) → flat S/P address in tStS0
# 2. Convert tStS0 flat address to (m, k) coordinate
# 3. Convert (m, k) to sP coordinate
# 4. Convert sP coordinate to sP flat address
#
# The TV layout from tiled_tmem_load is in terms of tStS0's data layout.
# tStS0 has layout ((128,128),1,1):((65536,1),0,0)
# So the flat address IS (m * 65536 + k), which is just m*65536 + k
# But tStS0's actual coordinate mapping is (m, k) where m is row, k is column
# Let's look at what layout_dst_tv_tiled gives us
dst_tv = tiled_tmem_load.layout_dst_tv_tiled
print(f"\n=== Destination TV Layout ===")
print(f" shape: {dst_tv.shape}")
print(f" stride: {dst_tv.stride}")
print(f" size: {cute.size(dst_tv)}")
# The destination TV layout maps (tid, vid) to the destination coordinate space.
# For the TMEM-load, destination is tStS0.
# tStS0 layout: ((128,128),1,1):((65536,1),0,0)
# So (tid, vid) → (addr_m * 65536 + addr_k * 1, 0, 0)
# The first component encodes (m, k) in the 128x128 S matrix
# We need to understand how (tid, vid) maps to (m, k)
# For a 128-thread partition (4 softmax warps × 32 threads/warp):
# Total values: 128 * 128 = 16384
num_softmax_threads = 128 # 4 warps
# Each thread has tTMEM_LOADcS size = 32 * 4 = 128 values (from shape ((32,1),4,1,1))
# The TV layout should have shape (num_threads, values_per_thread)
# From the TMEM-load TV layout:
tv_shape = dst_tv.shape
tv_stride = dst_tv.stride
print(f"\n TV shape breakdown: tid_dim={tv_shape[0] if tv_shape else '?'}, vid_dim={tv_shape[1] if len(tv_shape) > 1 else '?'}")
# Now let's try to build the sP address mapping
# sP outer layout (from diagnostics): ((128,16),1,(4,2),1):(((64,1),0,(16,8192)),0)
# With swizzle S<3,4,3>
# For make_cotiled_copy, we need:
# atom_layout_tv: (tid, vid) → sP flat address
# data_layout: sP flat layout (coalesced)
# Step 1: Extract (m, k) from dst_tv
# dst_tv maps (tid, vid) → tStS0 address
# tStS0 address = m * 65536 + k (from layout ((128,128),1,1):((65536,1),0,0))
# So m = addr // 65536, k = addr % 65536
# But since tStS0 has stride (65536, 1) for the first group ((128, 128)):
# addr = m * 65536 + k
# This gives us the (m, k) mapping implicitly.
# Step 2: Map (m, k) to sP address
# sP coordinate: ((m, k%16), 0, ((k//16)%4, k//64), 0)
# sP layout applies swizzle and stride
# sP flat address = sP.layout(((m, k%16), 0, ((k//16)%4, k//64), 0))
# For make_cotiled_copy, we need:
# atom_layout_tv: (tid, vid) → sP flat address
# We can compute this by composing the TV layout with a mapping from
# tStS0 address to sP address.
# Let me think about this differently.
# We know dst_tv maps (tid, vid) → tStS0 address (a flat integer)
# tStS0 address encodes (m, k) as m*65536 + k
# We need to convert this to an sP address
# Actually, let's use the identity tensor approach.
# cS = make_identity_tensor((128, 128)) maps (m, k) → (m, k)
# tScS = qk_thr.partition_C(cS) maps thread's C-fragment coordinates to (m, k)
# tTMEM_LOADcS = thr_load.partition_D(tScS) maps TMEM-load partition coordinates to (m, k)
# The key: tTMEM_LOADcS already gives us (m, k) per (thread_coord, vid_coord)
# We just need to evaluate ALL (thread, vid) pairs to get the full TV → (m, k) mapping
# Then compose with (m, k) → sP address
# Let's print some coordinate values to verify
print(f"\n=== Coordinate Verification ===")
# For thread 0: tTMEM_LOADcS shape is ((32,1),4,1,1)
# Values should be (m, k) pairs covering some subset of [0,128) x [0,128)
# Let's iterate and print a few
# We need to do this at trace time inside @cute.kernel
# For now, print the layout properties
print(f" tScS shape: {cute.shape(tScS)}")
print(f" tScS layout: {tScS.layout}")
print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
print(f" tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}")
# Try to understand the TV mapping from the layout properties
# The tTMEM_LOADcS layout: ((32,1),4,1,1):((1@1,0),32@1,0,0)
# This means: for index ((j0,0), j1, 0, 0):
# j0 ranges 0..31, j1 ranges 0..3
# coordinate = (j0 * 1, j1 * 32) = (j0, j1 * 32)
# But wait, the @1 in stride means there's a tmem-specific indexing
# Actually, the stride is ((1@1, 0), 32@1, 0, 0)
# @1 might mean "stride 1 in the second dimension of a 2D coordinate"
# Let me look at it from the TMEM load's perspective:
# Ld32x32bOp with Repetition(32) loads a 32x32 tile of FP32 from TMEM
# Each warp has 32 threads. 4 warps = 128 threads.
# Each thread loads 32 values per invocation, repeated 32 times?
# No - Repetition(32) means 32 repetitions of the base operation
# Ld32x32bOp base: 32 threads each load 32 values from a 32x32 TMEM tile
# Repetition(32): each thread loads 32 * 32 = 1024 values total?
# That's too many. With 128 threads, that's 128 * 1024 = 131072 >> 16384
# Hmm, let me reconsider. The partition_S/D shapes tell the story:
# tTMEM_LOADtS shape: (((32,32),1),4,1,1) - source (TMEM data)
# tTMEM_LOADcS shape: ((32,1),4,1,1) - destination (coordinate)
#
# The (32,1) in cS vs (32,32) in tS:
# The coordinate is 2D (m,k) so (32,1) means 32 coordinates
# The data is FP32 so (32,32) means 32*32 = 1024 values? No that's too many
#
# Actually, looking at the shapes:
# tTMEM_LOADtS: (((32,32),1),4,1,1) - this is the data register layout
# Inner (32,32) = 1024 values per fragment, 4 fragments = 4096 per thread
# 128 threads * 4096 = 524288 ≠ 16384
#
# This doesn't add up. The shape must represent the register layout differently.
# Let me just count: 32*1*4*1*1 = 128 elements per thread (coordinate)
# And 32*32*1*4*1*1 = 4096 elements per thread (data)
# 128 * 128 = 16384 coordinates ✓ (each coord is 2D = 32768 scalars)
# But 128 * 4096 = 524288 data elements ≠ 16384
#
# Something's off. The tTMEM_LOADtS data shape includes the FP32 register layout,
# not element count. The actual element count per thread matches the coordinate count
# because each (m,k) coordinate corresponds to one P value.
print(f"\n=== Element Count Verification ===")
cs_size = cute.size(tTMEM_LOADcS) # should be 128 (coordinates per thread)
ts_size = cute.size(tTMEM_LOADtS) # data elements per thread
print(f" tTMEM_LOADcS size: {cs_size} (coordinates per thread)")
print(f" tTMEM_LOADtS size: {ts_size} (data elements per thread)")
print(f" 128 threads * {cs_size} coords = {128 * cs_size} (should be 16384)")
# ====== Try make_cotiled_copy approach ======
# We need atom_layout_tv: (tid, vid) → sP flat address
# The TMEM-load TV layout gives us (tid, vid) → tStS0 address
# tStS0 address encodes (m, k)
# We need to compose with a (m, k) → sP address mapping
# Step 1: Get the TV layout that maps to the S/P matrix coordinates
# dst_tv = tiled_tmem_load.layout_dst_tv_tiled
# This maps (tid, vid) → coordinate in tStS0
# But tStS0 has a specific layout ((128,128),1,1):((65536,1),0,0)
# So the TV output is in terms of tStS0's flat address space
# Step 2: Build a layout that maps tStS0 flat address → sP flat address
# tStS0 address = m * 65536 + k (from stride)
# We need to map this to the sP address
# sP address = sP.layout(m, k%16, (k//16)%4, k//64) (after swizzle)
# This requires iterating over all (m, k) and computing sP addresses
# which we can do at Python time since both layouts are static
# Let's compute the mapping table
sP_outer = p_smem_s.outer
sP_swizzle = p_smem_s.inner
# Coalesce sP to get a flat 1D layout for address computation
sP_coalesced = cute.coalesce(sP_outer)
print(f"\n=== sP Coalesced Layout ===")
print(f" shape: {cute.shape(sP_coalesced)}")
print(f" layout: {sP_coalesced}")
# Compute (m, k) → sP address mapping for all 16384 elements
# For verification: build the address table
sP_shape = cute.shape(sP_outer)
print(f"\n=== Full sP shape: {sP_shape} ===")
# Let's try a different approach: use composition directly
# We know that:
# 1. The TMEM-load TV layout maps (tid, vid) → tStS0 coordinate
# 2. tStS0's layout has shape (128, 128) in the first group
# 3. We need (128, 128) → sP address
# Build a layout that maps the 128x128 P matrix to sP addresses
# For each (m, k) in [0,128) x [0,128):
# sP index = ((m, k%16), 0, ((k//16)%4, k//64), 0)
# sP address = sP_outer(sP index) (then swizzle applied by CuTe)
# We can build this as a CuTe layout:
# shape = (128, 128), mapping to sP addresses
# For k in [0, 128):
# k0 = k % 16
# k1 = (k // 16) % 4
# k2 = k // 64
# sP_addr = sP_outer((m, k0), 0, (k1, k2), 0)
# But CuTe layouts are affine (linear). The (k%16, (k//16)%4, k//64) decomposition
# is NOT affine unless 128 = 16 * 4 * 2 exactly (which it is!)
# k = k0 + 16*k1 + 64*k2 where k0∈[0,16), k1∈[0,4), k2∈[0,2)
# So k0 = k % 16, k1 = (k//16) % 4, k2 = k//64
# And the sP shape ((128, 16), 1, (4, 2), 1) has:
# mode 0: (128, 16) → stride (64, 1)
# mode 2: (4, 2) → stride (16, 8192)
# So sP_addr(m, k0, k1, k2) = m*64 + k0*1 + k1*16 + k2*8192
# Wait, but sP has swizzle S<3,4,3> which XORs some bits
# The swizzle is applied by CuTe when we use the tensor directly
# For make_cotiled_copy, we need the unswizzled (logical) layout
# because CuTe handles the swizzle at access time
# Actually, let me re-read make_cotiled_copy's requirement:
# atom_layout_tv: (tid, vid) -> data addr
# data_layout: data coord -> data addr
#
# The "data addr" here is the LOGICAL address (before swizzle)
# because CuTe applies swizzle at tensor access time
# So we need the sP outer layout WITHOUT swizzle for data_layout
# But p_smem_s.outer IS the logical layout (swizzle is in p_smem_s.inner)
# So for make_cotiled_copy:
# data_layout = sP_outer (logical, no swizzle)
# atom_layout_tv: (tid, vid) -> sP logical address
# The sP logical layout with shape ((128,16),1,(4,2),1) and strides
# can be coalesced to a 1D layout mapping (m, k) → address
# Let me compute the sP logical layout as a (128, 128) → address mapping
# sP_outer((m, k0), 0, (k1, k2), 0) = m*64 + k0*1 + k1*16 + k2*8192
# If we view this as a function of (m, k) where k = k0 + 16*k1 + 64*k2:
# sP_addr(m, k) = m*64 + (k%16)*1 + ((k//16)%4)*16 + (k//64)*8192
# = m*64 + (k%16) + 16*((k//16)%4) + 8192*(k//64)
# This is NOT a simple affine function of k alone (due to the modular decomposition)
# But it IS affine if we treat k as decomposed into (k0, k1, k2)
# k = k0 + 16*k1 + 64*k2
# sP_addr = 64*m + 1*k0 + 16*k1 + 8192*k2
# So the sP logical layout, viewed as a (128, 128) → address map,
# has stride (64, ?) where the k-stride varies based on k's position
# This is the standard column-major with subtiling
# For make_cotiled_copy, we need to express this as a CuTe layout
# The sP outer layout in its native form IS the correct data_layout
print("\n=== Plan for make_cotiled_copy ===")
print("1. data_layout = coalesce(sP_outer) — flat 1D sP address space")
print("2. atom_layout_tv = compose(tiled_tmem_load.layout_dst_tv_tiled, tStS_to_sP_mapping)")
print("3. tStS_to_sP_mapping: tStS address → sP address")
print(" tStS layout: (128,128) → address, stride=(65536,1)")
print(" sP layout: (128,16,4,2) → address (after coalescing), strides vary")
print("")
print("The problem: tStS has stride 65536 for m, while sP has stride 64 for m")
print("These are fundamentally different address spaces.")
print("We need to: (tid,vid) → (m,k) via identity, then (m,k) → sP_addr via sP layout")
# Actually, the cleaner approach:
# 1. Build identity tensor for (128, 128) P matrix
# 2. The TMEM-load partition + coordinate partition gives us (tid, vid) → (m, k)
# 3. The sP layout gives us (m, k) → sP address
# 4. Compose these to get (tid, vid) → sP address
# The identity tensor cS = make_identity_tensor((128, 128))
# maps (m, k) → (m, k) with layout (128, 128):((1, 128)) — row-major
# Wait, identity tensor layout is just (128, 128) with stride (1, 128)? Let me check.
# Actually, make_identity_tensor creates a tensor where each element holds its coordinate
# The layout is just an enumerate layout: (128, 128):(1, 128) in row-major
# So cS(m, k) = m + 128*k (flat index in row-major order)
# The tScS = qk_thr.partition_C(cS) partitions the identity by QK C-fragment threads
# The tTMEM_LOADcS = thr_load.partition_D(tScS) further partitions by TMEM-load threads
# So tTMEM_LOADcS maps (thread_local_index) → (m, k) coordinate
# The shape is ((32,1),4,1,1) with layout ((1@1,0),32@1,0,0)
# For index ((j0,0), j1, 0, 0):
# value = (j0, j1*32) — thread 0's coordinates
# But we need the FULL (tid, vid) → (m, k) mapping for ALL 128 threads
# The tTMEM_LOADcS is for a SINGLE thread (thread 0)
# The TV layout (layout_dst_tv_tiled) gives us the FULL mapping:
# (tid, vid) → destination coordinate (in tStS0's address space)
# But the destination is tStS0 which has layout ((128,128),1,1):((65536,1),0,0)
# So the TV output is in tStS0's flat address: m*65536 + k
# To convert to sP address, we need to:
# 1. Extract (m, k) from tStS0 flat address: m = addr // 65536, k = addr % 65536
# But this isn't simple because CuTe layouts aren't easily decomposable
# 2. Map (m, k) to sP address: use sP_outer((m, k%16), 0, ((k//16)%4, k//64), 0)
# Let's try using composition directly.
# We have:
# dst_tv: (tid, vid) → tStS0_addr
# tStS0 layout: ((128,128),1,1):((65536,1),0,0)
# We need: (tid, vid) → sP_addr
# Approach: build a "reindex" layout that maps tStS0_addr → sP_addr
# For each possible tStS0_addr, compute the corresponding sP_addr
# But tStS0_addr = m*65536 + k, and we need to decompose back to (m, k)
# With m ∈ [0, 128), k ∈ [0, 128), and stride 65536 for m and 1 for k
# We can build the inverse: given addr, m = addr // 65536, k = addr % 65536
# Then sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
# This is a 16384-element lookup table. We can build it at Python time.
# Let me build the lookup table
print(f"\n=== Building (tid, vid) → sP_addr mapping ===")
# Get the TV layout
tv = dst_tv
tv_shape = tv.shape
tv_stride = tv.stride
# We need to evaluate the TV layout for all (tid, vid) pairs
# and then map the resulting tStS0 address to sP address
# First, let's understand the TV layout structure
# It should have shape (num_threads, values_per_thread)
# num_threads = 128 (4 softmax warps)
# values_per_thread = 128 (32*4 from the coordinate shape)
print(f" TV shape: {tv_shape}")
print(f" TV stride: {tv_stride}")
print(f" Total elements: {cute.size(tv)} (should be 16384)")
# Evaluate the TV layout to get all (tid, vid) → tStS0_addr values
# We can do this by iterating over (tid, vid) and computing the layout value
# CuTe layout evaluation: layout(idx) = sum(idx_i * stride_i)
# For a 2D layout (tid, vid) → addr:
# addr = tid * stride[0] + vid * stride[1] (simplified)
# But the TV layout might have nested shapes/strides
# Let me just try to evaluate it
# Actually, for make_cotiled_copy, I don't need to manually compute the mapping.
# I can use CuTe's composition operation:
# atom_layout_tv = composition(sP_flat_layout, dst_tv)
# This would compose dst_tv: (tid, vid) → tStS0_addr
# with a mapping tStS0_addr → sP_addr
# But I need to build the tStS0_addr → sP_addr mapping first
# Let me build it as a CuTe layout
# tStS0 has layout ((128,128),1,1):((65536,1),0,0)
# Coalescing: (128,128):(65536,1) — so tStS0_addr = m*65536 + k
# The inverse is: m = addr // 65536, k = addr % 65536
# sP has layout ((128,16),1,(4,2),1):((64,1),0,(16,8192),0) (before swizzle)
# Coalescing: ((128,16,4,2),1,1,1):((64,1,16,8192),0,0,0) → flat
# sP_addr = 64*m + 1*k0 + 16*k1 + 8192*k2
# = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
# To build tStS0_addr → sP_addr as a CuTe layout:
# We need a layout with shape matching tStS0's domain (16384 elements)
# But CuTe layouts are affine, and the mapping is NOT affine
# (due to the modular decomposition of k)
# So we CAN'T use composition directly. We need to build a custom layout.
# Alternative: build the sP layout as a (128, 128) → addr layout directly
# sP_addr(m, k) = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
# This can be expressed as:
# sP_addr = 64*m + f(k) where f(k) = (k%16) + 16*((k//16)%4) + 8192*(k//64)
# f(k) IS representable as a CuTe layout because 16*4*2 = 128
# k can be decomposed as (k0, k1, k2) where k0=k%16, k1=(k//16)%4, k2=k//64
# f(k) = k0 + 16*k1 + 8192*k2 = (k0, k1, k2) → (1, 16, 8192) → addr
# So the full sP layout in (m, k) coordinates:
# (m, k) → (m, k0, k1, k2) → (64*m + k0 + 16*k1 + 8192*k2)
# = (m, k0, k1, k2) → (64, 1, 16, 8192) → sP_addr
# And the tStS0 layout:
# (m, k) → (65536*m + k) → tStS0_addr
# The mapping tStS0_addr → (m, k) is the INVERSE of the tStS0 layout:
# (65536, 1) → (m, k) — this is what left_inverse gives
# Then (m, k) → sP_addr via the sP layout
# So the full chain