shit left dangling
This commit is contained in:
516
tests/unit/test_smem_p_diag.py
Normal file
516
tests/unit/test_smem_p_diag.py
Normal file
@@ -0,0 +1,516 @@
|
||||
"""
|
||||
Diagnostic test for SMEM-P: verify the TV layout mapping from TMEM-load to sP.
|
||||
|
||||
This test does NOT compute attention. It:
|
||||
1. Creates the TMEM-load copy and extracts its TV layout
|
||||
2. Creates sP with PV A-operand layout
|
||||
3. Builds atom_layout_tv mapping (tid,vid) → sP address
|
||||
4. Validates the mapping is correct (all 16384 elements covered, no overlaps)
|
||||
5. Tests the make_cotiled_copy approach
|
||||
"""
|
||||
import torch, math
|
||||
import cutlass, cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
import cutlass.utils.blackwell_helpers as bh
|
||||
from cutlass.cute.nvgpu import tcgen05, cpasync
|
||||
from cutlass import Float32, BFloat16, Int32
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
def diag_smem_p_tv():
|
||||
"""Print TV layout info from TMEM-load copy and sP to build the R→S copy."""
|
||||
head_dim = 256 # First SMEM-P head dim (not 64 which uses TMEM-P)
|
||||
s_k = 128
|
||||
m = 128
|
||||
pv_n_tile = min(head_dim, 256)
|
||||
scale_softmax = 1.0 / math.sqrt(head_dim)
|
||||
scale_softmax_log2 = scale_softmax * math.log2(math.e)
|
||||
|
||||
# Create tensors (shapes don't matter much, we just need the layouts)
|
||||
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
# Create CuTe tensors
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
|
||||
# V layout: (s_k, pv_n_tile, 1) as built below — NOTE(review): confirm the FMHA kernel does not expect (pv_n_tile, s_k, 1)
|
||||
v_fmha = v[:, 0:pv_n_tile].contiguous()
|
||||
v_kernel = v_fmha.unsqueeze(-1)
|
||||
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
|
||||
mC = ct.from_dlpack(c[:, 0:pv_n_tile, :]).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c[:, 0:pv_n_tile, :]))
|
||||
|
||||
# MMA setup
|
||||
a_major = cute.LayoutEnum.from_tensor(mQ).mma_major_mode()
|
||||
b_major = cute.LayoutEnum.from_tensor(mK).mma_major_mode()
|
||||
v_major = cute.LayoutEnum.from_tensor(mV).mma_major_mode()
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, a_major, b_major, Float32,
|
||||
tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM
|
||||
)
|
||||
pv_a_major = a_major # SMEM-P: A from SMEM
|
||||
pv_source = tcgen05.OperandSource.SMEM
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, pv_a_major, v_major, Float32,
|
||||
tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source
|
||||
)
|
||||
|
||||
# Compute layouts (mirrors FmhaKernel._setup)
|
||||
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
qk_mma_tiler = (128, 128, qk_ik * 4)
|
||||
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
|
||||
|
||||
# P SMEM layout (PV A-operand, 1 stage)
|
||||
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
|
||||
|
||||
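# P TMEM layout (for reference) — NOTE(review): built with make_smem_layout_a and the same arguments as p_smem_s, so this is currently an SMEM layout, not a TMEM one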
|
||||
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
|
||||
|
||||
# QK C-fragment
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) # offset 0 for diagnostics
|
||||
|
||||
# PV C-fragment
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
|
||||
# ====== TMEM-load copy (softmax reads S from TMEM) ======
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
|
||||
)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
|
||||
print("=== TMEM-Load Copy Info ===")
|
||||
print(f" layout_tv_tiled: {tiled_tmem_load.layout_tv_tiled}")
|
||||
print(f" layout_dst_tv_tiled: {tiled_tmem_load.layout_dst_tv_tiled}")
|
||||
print(f" layout_src_tv_tiled: {tiled_tmem_load.layout_src_tv_tiled}")
|
||||
print(f" TV shape: {tiled_tmem_load.layout_tv_tiled.shape}")
|
||||
print(f" TV stride: {tiled_tmem_load.layout_tv_tiled.stride}")
|
||||
|
||||
# Softmax thread partition
|
||||
sfw_idx = 0 # thread 0 in softmax warps
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
|
||||
# Coordinate identity tensor
|
||||
cS = cute.make_identity_tensor((128, 128))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
print(f"\n=== Per-Thread Partition (thread 0) ===")
|
||||
print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}")
|
||||
print(f" tTMEM_LOADtS layout: {tTMEM_LOADtS.layout}")
|
||||
print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
|
||||
print(f" tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}")
|
||||
|
||||
# ====== sP layout (PV A-operand SMEM) ======
|
||||
# Remove stage dim
|
||||
sP_stage = p_smem_s.outer
|
||||
if len(cute.shape(sP_stage)) > 3:
|
||||
# Try removing the stage dimension
|
||||
sP_stage = sP_stage # might need slicing
|
||||
|
||||
print(f"\n=== sP Layout (PV A-operand SMEM) ===")
|
||||
print(f" p_smem_s.outer shape: {cute.shape(p_smem_s.outer)}")
|
||||
print(f" p_smem_s.outer layout: {p_smem_s.outer}")
|
||||
print(f" p_smem_s.inner (swizzle): {p_smem_s.inner}")
|
||||
|
||||
# Flatten sP to 1D for address computation
|
||||
sP_flat = cute.coalesce(p_smem_s.outer)
|
||||
print(f" sP_flat shape: {cute.shape(sP_flat)}")
|
||||
print(f" sP_flat layout: {sP_flat}")
|
||||
|
||||
# ====== Build atom_layout_tv ======
|
||||
# The TV layout of tiled_tmem_load maps (tid, vid) → tStS0 coordinate
|
||||
# We need to compose with (m, k) → sP address mapping
|
||||
#
|
||||
# Approach:
|
||||
# 1. Get the TV layout: (tid, vid) → flat S/P address in tStS0
|
||||
# 2. Convert tStS0 flat address to (m, k) coordinate
|
||||
# 3. Convert (m, k) to sP coordinate
|
||||
# 4. Convert sP coordinate to sP flat address
|
||||
#
|
||||
# The TV layout from tiled_tmem_load is in terms of tStS0's data layout.
|
||||
# tStS0 has layout ((128,128),1,1):((65536,1),0,0)
|
||||
# So the flat address IS (m * 65536 + k), which is just m*65536 + k
|
||||
# But tStS0's actual coordinate mapping is (m, k) where m is row, k is column
|
||||
|
||||
# Let's look at what layout_dst_tv_tiled gives us
|
||||
dst_tv = tiled_tmem_load.layout_dst_tv_tiled
|
||||
print(f"\n=== Destination TV Layout ===")
|
||||
print(f" shape: {dst_tv.shape}")
|
||||
print(f" stride: {dst_tv.stride}")
|
||||
print(f" size: {cute.size(dst_tv)}")
|
||||
|
||||
# The destination TV layout maps (tid, vid) to the destination coordinate space.
|
||||
# For the TMEM-load, destination is tStS0.
|
||||
# tStS0 layout: ((128,128),1,1):((65536,1),0,0)
|
||||
# So (tid, vid) → (addr_m * 65536 + addr_k * 1, 0, 0)
|
||||
# The first component encodes (m, k) in the 128x128 S matrix
|
||||
|
||||
# We need to understand how (tid, vid) maps to (m, k)
|
||||
# For a 128-thread partition (4 softmax warps × 32 threads/warp):
|
||||
# Total values: 128 * 128 = 16384
|
||||
|
||||
num_softmax_threads = 128 # 4 warps
|
||||
# Each thread has tTMEM_LOADcS size = 32 * 4 = 128 values (from shape ((32,1),4,1,1))
|
||||
|
||||
# The TV layout should have shape (num_threads, values_per_thread)
|
||||
# From the TMEM-load TV layout:
|
||||
tv_shape = dst_tv.shape
|
||||
tv_stride = dst_tv.stride
|
||||
print(f"\n TV shape breakdown: tid_dim={tv_shape[0] if tv_shape else '?'}, vid_dim={tv_shape[1] if len(tv_shape) > 1 else '?'}")
|
||||
|
||||
# Now let's try to build the sP address mapping
|
||||
# sP outer layout (from diagnostics): ((128,16),1,(4,2),1):((64,1),0,(16,8192),0)
|
||||
# With swizzle S<3,4,3>
|
||||
|
||||
# For make_cotiled_copy, we need:
|
||||
# atom_layout_tv: (tid, vid) → sP flat address
|
||||
# data_layout: sP flat layout (coalesced)
|
||||
|
||||
# Step 1: Extract (m, k) from dst_tv
|
||||
# dst_tv maps (tid, vid) → tStS0 address
|
||||
# tStS0 address = m * 65536 + k (from layout ((128,128),1,1):((65536,1),0,0))
|
||||
# So m = addr // 65536, k = addr % 65536
|
||||
# But since tStS0 has stride (65536, 1) for the first group ((128, 128)):
|
||||
# addr = m * 65536 + k
|
||||
# This gives us the (m, k) mapping implicitly.
|
||||
|
||||
# Step 2: Map (m, k) to sP address
|
||||
# sP coordinate: ((m, k%16), 0, ((k//16)%4, k//64), 0)
|
||||
# sP layout applies swizzle and stride
|
||||
# sP flat address = sP.layout(((m, k%16), 0, ((k//16)%4, k//64), 0))
|
||||
|
||||
# For make_cotiled_copy, we need:
|
||||
# atom_layout_tv: (tid, vid) → sP flat address
|
||||
# We can compute this by composing the TV layout with a mapping from
|
||||
# tStS0 address to sP address.
|
||||
|
||||
# Let me think about this differently.
|
||||
# We know dst_tv maps (tid, vid) → tStS0 address (a flat integer)
|
||||
# tStS0 address encodes (m, k) as m*65536 + k
|
||||
# We need to convert this to an sP address
|
||||
|
||||
# Actually, let's use the identity tensor approach.
|
||||
# cS = make_identity_tensor((128, 128)) maps (m, k) → (m, k)
|
||||
# tScS = qk_thr.partition_C(cS) maps thread's C-fragment coordinates to (m, k)
|
||||
# tTMEM_LOADcS = thr_load.partition_D(tScS) maps TMEM-load partition coordinates to (m, k)
|
||||
|
||||
# The key: tTMEM_LOADcS already gives us (m, k) per (thread_coord, vid_coord)
|
||||
# We just need to evaluate ALL (thread, vid) pairs to get the full TV → (m, k) mapping
|
||||
# Then compose with (m, k) → sP address
|
||||
|
||||
# Let's print some coordinate values to verify
|
||||
print(f"\n=== Coordinate Verification ===")
|
||||
# For thread 0: tTMEM_LOADcS shape is ((32,1),4,1,1)
|
||||
# Values should be (m, k) pairs covering some subset of [0,128) x [0,128)
|
||||
# Let's iterate and print a few
|
||||
|
||||
# We need to do this at trace time inside @cute.kernel
|
||||
# For now, print the layout properties
|
||||
print(f" tScS shape: {cute.shape(tScS)}")
|
||||
print(f" tScS layout: {tScS.layout}")
|
||||
print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
|
||||
print(f" tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}")
|
||||
|
||||
# Try to understand the TV mapping from the layout properties
|
||||
# The tTMEM_LOADcS layout: ((32,1),4,1,1):((1@1,0),32@1,0,0)
|
||||
# This means: for index ((j0,0), j1, 0, 0):
|
||||
# j0 ranges 0..31, j1 ranges 0..3
|
||||
# coordinate = (j0 * 1, j1 * 32) = (j0, j1 * 32)
|
||||
# But wait, the @1 in stride means there's a tmem-specific indexing
|
||||
# Actually, the stride is ((1@1, 0), 32@1, 0, 0)
|
||||
# @1 might mean "stride 1 in the second dimension of a 2D coordinate"
|
||||
|
||||
# Let me look at it from the TMEM load's perspective:
|
||||
# Ld32x32bOp with Repetition(32) loads a 32x32 tile of FP32 from TMEM
|
||||
# Each warp has 32 threads. 4 warps = 128 threads.
|
||||
# Each thread loads 32 values per invocation, repeated 32 times?
|
||||
# No - Repetition(32) means 32 repetitions of the base operation
|
||||
# Ld32x32bOp base: 32 threads each load 32 values from a 32x32 TMEM tile
|
||||
# Repetition(32): each thread loads 32 * 32 = 1024 values total?
|
||||
# That's too many. With 128 threads, that's 128 * 1024 = 131072 >> 16384
|
||||
|
||||
# Hmm, let me reconsider. The partition_S/D shapes tell the story:
|
||||
# tTMEM_LOADtS shape: (((32,32),1),4,1,1) - source (TMEM data)
|
||||
# tTMEM_LOADcS shape: ((32,1),4,1,1) - destination (coordinate)
|
||||
#
|
||||
# The (32,1) in cS vs (32,32) in tS:
|
||||
# The coordinate is 2D (m,k) so (32,1) means 32 coordinates
|
||||
# The data is FP32 so (32,32) means 32*32 = 1024 values? No that's too many
|
||||
#
|
||||
# Actually, looking at the shapes:
|
||||
# tTMEM_LOADtS: (((32,32),1),4,1,1) - this is the data register layout
|
||||
# Inner (32,32) = 1024 values per fragment, 4 fragments = 4096 per thread
|
||||
# 128 threads * 4096 = 524288 ≠ 16384
|
||||
#
|
||||
# This doesn't add up. The shape must represent the register layout differently.
|
||||
# Let me just count: 32*1*4*1*1 = 128 elements per thread (coordinate)
|
||||
# And 32*32*1*4*1*1 = 4096 elements per thread (data)
|
||||
# 128 * 128 = 16384 coordinates ✓ (each coord is 2D = 32768 scalars)
|
||||
# But 128 * 4096 = 524288 data elements ≠ 16384
|
||||
#
|
||||
# Something's off. The tTMEM_LOADtS data shape includes the FP32 register layout,
|
||||
# not element count. The actual element count per thread matches the coordinate count
|
||||
# because each (m,k) coordinate corresponds to one P value.
|
||||
|
||||
print(f"\n=== Element Count Verification ===")
|
||||
cs_size = cute.size(tTMEM_LOADcS) # should be 128 (coordinates per thread)
|
||||
ts_size = cute.size(tTMEM_LOADtS) # data elements per thread
|
||||
print(f" tTMEM_LOADcS size: {cs_size} (coordinates per thread)")
|
||||
print(f" tTMEM_LOADtS size: {ts_size} (data elements per thread)")
|
||||
print(f" 128 threads * {cs_size} coords = {128 * cs_size} (should be 16384)")
|
||||
|
||||
# ====== Try make_cotiled_copy approach ======
|
||||
# We need atom_layout_tv: (tid, vid) → sP flat address
|
||||
# The TMEM-load TV layout gives us (tid, vid) → tStS0 address
|
||||
# tStS0 address encodes (m, k)
|
||||
# We need to compose with a (m, k) → sP address mapping
|
||||
|
||||
# Step 1: Get the TV layout that maps to the S/P matrix coordinates
|
||||
# dst_tv = tiled_tmem_load.layout_dst_tv_tiled
|
||||
# This maps (tid, vid) → coordinate in tStS0
|
||||
# But tStS0 has a specific layout ((128,128),1,1):((65536,1),0,0)
|
||||
# So the TV output is in terms of tStS0's flat address space
|
||||
|
||||
# Step 2: Build a layout that maps tStS0 flat address → sP flat address
|
||||
# tStS0 address = m * 65536 + k (from stride)
|
||||
# We need to map this to the sP address
|
||||
# sP address = sP.layout(m, k%16, (k//16)%4, k//64) (after swizzle)
|
||||
|
||||
# This requires iterating over all (m, k) and computing sP addresses
|
||||
# which we can do at Python time since both layouts are static
|
||||
|
||||
# Let's compute the mapping table
|
||||
sP_outer = p_smem_s.outer
|
||||
sP_swizzle = p_smem_s.inner
|
||||
|
||||
# Coalesce sP to get a flat 1D layout for address computation
|
||||
sP_coalesced = cute.coalesce(sP_outer)
|
||||
print(f"\n=== sP Coalesced Layout ===")
|
||||
print(f" shape: {cute.shape(sP_coalesced)}")
|
||||
print(f" layout: {sP_coalesced}")
|
||||
|
||||
# Compute (m, k) → sP address mapping for all 16384 elements
|
||||
# For verification: build the address table
|
||||
sP_shape = cute.shape(sP_outer)
|
||||
print(f"\n=== Full sP shape: {sP_shape} ===")
|
||||
|
||||
# Let's try a different approach: use composition directly
|
||||
# We know that:
|
||||
# 1. The TMEM-load TV layout maps (tid, vid) → tStS0 coordinate
|
||||
# 2. tStS0's layout has shape (128, 128) in the first group
|
||||
# 3. We need (128, 128) → sP address
|
||||
|
||||
# Build a layout that maps the 128x128 P matrix to sP addresses
|
||||
# For each (m, k) in [0,128) x [0,128):
|
||||
# sP index = ((m, k%16), 0, ((k//16)%4, k//64), 0)
|
||||
# sP address = sP_outer(sP index) (then swizzle applied by CuTe)
|
||||
|
||||
# We can build this as a CuTe layout:
|
||||
# shape = (128, 128), mapping to sP addresses
|
||||
# For k in [0, 128):
|
||||
# k0 = k % 16
|
||||
# k1 = (k // 16) % 4
|
||||
# k2 = k // 64
|
||||
# sP_addr = sP_outer((m, k0), 0, (k1, k2), 0)
|
||||
|
||||
# But CuTe layouts are affine (linear). The (k%16, (k//16)%4, k//64) decomposition
|
||||
# is NOT affine unless 128 = 16 * 4 * 2 exactly (which it is!)
|
||||
# k = k0 + 16*k1 + 64*k2 where k0∈[0,16), k1∈[0,4), k2∈[0,2)
|
||||
# So k0 = k % 16, k1 = (k//16) % 4, k2 = k//64
|
||||
# And the sP shape ((128, 16), 1, (4, 2), 1) has:
|
||||
# mode 0: (128, 16) → stride (64, 1)
|
||||
# mode 2: (4, 2) → stride (16, 8192)
|
||||
# So sP_addr(m, k0, k1, k2) = m*64 + k0*1 + k1*16 + k2*8192
|
||||
|
||||
# Wait, but sP has swizzle S<3,4,3> which XORs some bits
|
||||
# The swizzle is applied by CuTe when we use the tensor directly
|
||||
# For make_cotiled_copy, we need the unswizzled (logical) layout
|
||||
# because CuTe handles the swizzle at access time
|
||||
|
||||
# Actually, let me re-read make_cotiled_copy's requirement:
|
||||
# atom_layout_tv: (tid, vid) -> data addr
|
||||
# data_layout: data coord -> data addr
|
||||
#
|
||||
# The "data addr" here is the LOGICAL address (before swizzle)
|
||||
# because CuTe applies swizzle at tensor access time
|
||||
# So we need the sP outer layout WITHOUT swizzle for data_layout
|
||||
# But p_smem_s.outer IS the logical layout (swizzle is in p_smem_s.inner)
|
||||
|
||||
# So for make_cotiled_copy:
|
||||
# data_layout = sP_outer (logical, no swizzle)
|
||||
# atom_layout_tv: (tid, vid) -> sP logical address
|
||||
|
||||
# The sP logical layout with shape ((128,16),1,(4,2),1) and strides
|
||||
# can be coalesced to a 1D layout mapping (m, k) → address
|
||||
|
||||
# Let me compute the sP logical layout as a (128, 128) → address mapping
|
||||
# sP_outer((m, k0), 0, (k1, k2), 0) = m*64 + k0*1 + k1*16 + k2*8192
|
||||
|
||||
# If we view this as a function of (m, k) where k = k0 + 16*k1 + 64*k2:
|
||||
# sP_addr(m, k) = m*64 + (k%16)*1 + ((k//16)%4)*16 + (k//64)*8192
|
||||
# = m*64 + (k%16) + 16*((k//16)%4) + 8192*(k//64)
|
||||
|
||||
# This is NOT a simple affine function of k alone (due to the modular decomposition)
|
||||
# But it IS affine if we treat k as decomposed into (k0, k1, k2)
|
||||
# k = k0 + 16*k1 + 64*k2
|
||||
# sP_addr = 64*m + 1*k0 + 16*k1 + 8192*k2
|
||||
|
||||
# So the sP logical layout, viewed as a (128, 128) → address map,
|
||||
# has stride (64, ?) where the k-stride varies based on k's position
|
||||
# This is a K-major layout (k is contiguous within each 64-wide K tile, m has stride 64) with subtiling
|
||||
|
||||
# For make_cotiled_copy, we need to express this as a CuTe layout
|
||||
# The sP outer layout in its native form IS the correct data_layout
|
||||
|
||||
print("\n=== Plan for make_cotiled_copy ===")
|
||||
print("1. data_layout = coalesce(sP_outer) — flat 1D sP address space")
|
||||
print("2. atom_layout_tv = compose(tiled_tmem_load.layout_dst_tv_tiled, tStS_to_sP_mapping)")
|
||||
print("3. tStS_to_sP_mapping: tStS address → sP address")
|
||||
print(" tStS layout: (128,128) → address, stride=(65536,1)")
|
||||
print(" sP layout: (128,16,4,2) → address (after coalescing), strides vary")
|
||||
print("")
|
||||
print("The problem: tStS has stride 65536 for m, while sP has stride 64 for m")
|
||||
print("These are fundamentally different address spaces.")
|
||||
print("We need to: (tid,vid) → (m,k) via identity, then (m,k) → sP_addr via sP layout")
|
||||
|
||||
# Actually, the cleaner approach:
|
||||
# 1. Build identity tensor for (128, 128) P matrix
|
||||
# 2. The TMEM-load partition + coordinate partition gives us (tid, vid) → (m, k)
|
||||
# 3. The sP layout gives us (m, k) → sP address
|
||||
# 4. Compose these to get (tid, vid) → sP address
|
||||
|
||||
# The identity tensor cS = make_identity_tensor((128, 128))
|
||||
# maps (m, k) → (m, k) with layout (128, 128):(1, 128) — column-major (m is the contiguous mode)
|
||||
# Wait, identity tensor layout is just (128, 128) with stride (1, 128)? Let me check.
|
||||
|
||||
# Actually, make_identity_tensor creates a tensor where each element holds its coordinate
|
||||
# The layout is just an enumerate layout: (128, 128):(1, 128), i.e. column-major
|
||||
# So cS(m, k) = m + 128*k (flat index in column-major order)
|
||||
|
||||
# The tScS = qk_thr.partition_C(cS) partitions the identity by QK C-fragment threads
|
||||
# The tTMEM_LOADcS = thr_load.partition_D(tScS) further partitions by TMEM-load threads
|
||||
|
||||
# So tTMEM_LOADcS maps (thread_local_index) → (m, k) coordinate
|
||||
# The shape is ((32,1),4,1,1) with layout ((1@1,0),32@1,0,0)
|
||||
# For index ((j0,0), j1, 0, 0):
|
||||
# value = (j0, j1*32) — thread 0's coordinates
|
||||
|
||||
# But we need the FULL (tid, vid) → (m, k) mapping for ALL 128 threads
|
||||
# The tTMEM_LOADcS is for a SINGLE thread (thread 0)
|
||||
|
||||
# The TV layout (layout_dst_tv_tiled) gives us the FULL mapping:
|
||||
# (tid, vid) → destination coordinate (in tStS0's address space)
|
||||
# But the destination is tStS0 which has layout ((128,128),1,1):((65536,1),0,0)
|
||||
# So the TV output is in tStS0's flat address: m*65536 + k
|
||||
|
||||
# To convert to sP address, we need to:
|
||||
# 1. Extract (m, k) from tStS0 flat address: m = addr // 65536, k = addr % 65536
|
||||
# But this isn't simple because CuTe layouts aren't easily decomposable
|
||||
# 2. Map (m, k) to sP address: use sP_outer((m, k%16), 0, ((k//16)%4, k//64), 0)
|
||||
|
||||
# Let's try using composition directly.
|
||||
# We have:
|
||||
# dst_tv: (tid, vid) → tStS0_addr
|
||||
# tStS0 layout: ((128,128),1,1):((65536,1),0,0)
|
||||
# We need: (tid, vid) → sP_addr
|
||||
|
||||
# Approach: build a "reindex" layout that maps tStS0_addr → sP_addr
|
||||
# For each possible tStS0_addr, compute the corresponding sP_addr
|
||||
|
||||
# But tStS0_addr = m*65536 + k, and we need to decompose back to (m, k)
|
||||
# With m ∈ [0, 128), k ∈ [0, 128), and stride 65536 for m and 1 for k
|
||||
# We can build the inverse: given addr, m = addr // 65536, k = addr % 65536
|
||||
# Then sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
|
||||
|
||||
# This is a 16384-element lookup table. We can build it at Python time.
|
||||
|
||||
# Let me build the lookup table
|
||||
print(f"\n=== Building (tid, vid) → sP_addr mapping ===")
|
||||
|
||||
# Get the TV layout
|
||||
tv = dst_tv
|
||||
tv_shape = tv.shape
|
||||
tv_stride = tv.stride
|
||||
|
||||
# We need to evaluate the TV layout for all (tid, vid) pairs
|
||||
# and then map the resulting tStS0 address to sP address
|
||||
|
||||
# First, let's understand the TV layout structure
|
||||
# It should have shape (num_threads, values_per_thread)
|
||||
# num_threads = 128 (4 softmax warps)
|
||||
# values_per_thread = 128 (32*4 from the coordinate shape)
|
||||
|
||||
print(f" TV shape: {tv_shape}")
|
||||
print(f" TV stride: {tv_stride}")
|
||||
print(f" Total elements: {cute.size(tv)} (should be 16384)")
|
||||
|
||||
# Evaluate the TV layout to get all (tid, vid) → tStS0_addr values
|
||||
# We can do this by iterating over (tid, vid) and computing the layout value
|
||||
|
||||
# CuTe layout evaluation: layout(idx) = sum(idx_i * stride_i)
|
||||
# For a 2D layout (tid, vid) → addr:
|
||||
# addr = tid * stride[0] + vid * stride[1] (simplified)
|
||||
|
||||
# But the TV layout might have nested shapes/strides
|
||||
# Let me just try to evaluate it
|
||||
|
||||
# Actually, for make_cotiled_copy, I don't need to manually compute the mapping.
|
||||
# I can use CuTe's composition operation:
|
||||
# atom_layout_tv = composition(sP_flat_layout, dst_tv)
|
||||
# This would compose dst_tv: (tid, vid) → tStS0_addr
|
||||
# with a mapping tStS0_addr → sP_addr
|
||||
|
||||
# But I need to build the tStS0_addr → sP_addr mapping first
|
||||
# Let me build it as a CuTe layout
|
||||
|
||||
# tStS0 has layout ((128,128),1,1):((65536,1),0,0)
|
||||
# Coalescing: (128,128):(65536,1) — so tStS0_addr = m*65536 + k
|
||||
# The inverse is: m = addr // 65536, k = addr % 65536
|
||||
|
||||
# sP has layout ((128,16),1,(4,2),1):((64,1),0,(16,8192),0) (before swizzle)
|
||||
# Coalescing: ((128,16,4,2),1,1,1):((64,1,16,8192),0,0,0) → flat
|
||||
# sP_addr = 64*m + 1*k0 + 16*k1 + 8192*k2
|
||||
# = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
|
||||
|
||||
# To build tStS0_addr → sP_addr as a CuTe layout:
|
||||
# We need a layout with shape matching tStS0's domain (16384 elements)
|
||||
# But CuTe layouts are affine, and the mapping is NOT affine
|
||||
# (due to the modular decomposition of k)
|
||||
|
||||
# So we CAN'T use composition directly. We need to build a custom layout.
|
||||
|
||||
# Alternative: build the sP layout as a (128, 128) → addr layout directly
|
||||
# sP_addr(m, k) = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
|
||||
# This can be expressed as:
|
||||
# sP_addr = 64*m + f(k) where f(k) = (k%16) + 16*((k//16)%4) + 8192*(k//64)
|
||||
|
||||
# f(k) IS representable as a CuTe layout because 16*4*2 = 128
|
||||
# k can be decomposed as (k0, k1, k2) where k0=k%16, k1=(k//16)%4, k2=k//64
|
||||
# f(k) = k0 + 16*k1 + 8192*k2 = (k0, k1, k2) → (1, 16, 8192) → addr
|
||||
|
||||
# So the full sP layout in (m, k) coordinates:
|
||||
# (m, k) → (m, k0, k1, k2) → (64*m + k0 + 16*k1 + 8192*k2)
|
||||
# = (m, k0, k1, k2) → (64, 1, 16, 8192) → sP_addr
|
||||
|
||||
# And the tStS0 layout:
|
||||
# (m, k) → (65536*m + k) → tStS0_addr
|
||||
|
||||
# The mapping tStS0_addr → (m, k) is the INVERSE of the tStS0 layout:
|
||||
# (65536, 1) → (m, k) — this is what left_inverse gives
|
||||
|
||||
# Then (m, k) → sP_addr via the sP layout
|
||||
|
||||
# So the full chain
|
||||
Reference in New Issue
Block a user