D1.3: Add make_cotiled_copy diagnostic test

This commit is contained in:
2026-05-24 00:05:48 +00:00
parent d4659d661d
commit 63e3ed0fed

View File

@@ -0,0 +1,297 @@
"""
D1.3 SMEM-P: Diagnostic for make_cotiled_copy approach.
Goal: Build a custom R→S tiled copy that maps softmax thread registers
(TMEM-load ownership) to sP (PV A-operand SMEM with swizzle).
Steps:
1. Print tiled_tmem_load TV layout shapes
2. Print tTMEM_LOADcS coordinate partition
3. Build atom_layout_tv: (tid, vid) -> sP address
4. Create make_cotiled_copy and print partition shapes
5. Test: write P to sP via cotiled copy, read back via PV MMA, verify
"""
import torch, math
import cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cutlass.torch as ct
import cuda.bindings.driver as cuda
def test_cotiled_copy_diag() -> None:
    """Print the layouts needed to design an R->S copy from softmax registers to sP.

    The function builds the QK and PV tiled MMAs, the sP shared-memory layout
    and a TMEM-load tiled copy, then prints:

    * the TMEM load's destination TV layout and its shape,
    * thread 0's source partition and coordinate partition,
    * the outer layout and the swizzle of sP,
    * a count check (threads x coordinates per thread vs. tile size).

    It does not build or run a ``make_cotiled_copy``; the comments below only
    record the design reasoning for that follow-up step.  Nothing is returned
    and nothing is asserted -- all results go to stdout.

    NOTE(review): every CuTe call here runs in plain Python, outside any
    ``@cute.jit`` function -- confirm that the DSL supports this.
    """
    print("=== make_cotiled_copy Diagnostic ===\n")

    head_dim = 64  # Start with proven hd=64 (TMEM-P works at cos 0.973)
    pv_n_tile = min(head_dim, 256)
    qk_mma_tiler = (128, 128, 128 * 4)
    pv_mma_tiler = (128, pv_n_tile, 128)
    tile_m, tile_n = qk_mma_tiler[:2]
    # The messages below describe the softmax group as warps 0-3, i.e. 4 x 32.
    num_softmax_threads = 128

    # Build MMA objects.
    a_major = LayoutEnum.ROW_MAJOR
    b_major = LayoutEnum.ROW_MAJOR
    v_major = LayoutEnum.ROW_MAJOR
    qk_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, b_major, Float32,
        tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM
    )
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, v_major, Float32,
        tcgen05.CtaGroup.ONE, (128, pv_n_tile), tcgen05.OperandSource.SMEM
    )

    # SMEM layout of the PV A operand (this is sP), with a single stage.
    p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)

    # QK C-fragment: the tensor the TMEM load reads from.
    qk_thr = qk_mma.get_slice(0)
    qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_as)

    # TMEM load atom and the tiled copy derived from it.
    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
    )
    tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)

    # TV layout of the TMEM load.
    print(f"tiled_tmem_load shape: {cute.shape(tiled_tmem_load)}")
    tv_layout = tiled_tmem_load.layout_dst_tv_tiled
    print(f"TV layout: {tv_layout}")
    print(f"TV layout shape: {cute.shape(tv_layout)}")

    # Per-thread partitions, shown for the first softmax thread.
    sfw_idx = 0
    thr_load = tiled_tmem_load.get_slice(sfw_idx)
    tTMEM_LOADtS = thr_load.partition_S(tStS)
    print(f"\ntTMEM_LOADtS shape (thread 0): {cute.shape(tTMEM_LOADtS)}")
    print(f"tTMEM_LOADtS layout: {tTMEM_LOADtS.layout}")

    # Partitioning an identity tensor yields the coordinates of the tile that
    # this thread receives from the load.
    cS = cute.make_identity_tensor((tile_m, tile_n))
    tScS = qk_thr.partition_C(cS)
    tTMEM_LOADcS = thr_load.partition_D(tScS)
    print(f"tTMEM_LOADcS shape (thread 0): {cute.shape(tTMEM_LOADcS)}")
    print(f"tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}")

    # sP layout: `outer` is the layout without the swizzle, `inner` is the swizzle.
    sP_outer = p_smem_s.outer
    print(f"\nsP outer shape: {cute.shape(sP_outer)}")
    print(f"sP outer layout: {sP_outer}")
    print(f"sP inner (swizzle): {p_smem_s.inner}")

    # Sanity check: threads x coordinates-per-thread should cover the QK tile.
    coords_per_thread = cute.size(tTMEM_LOADcS)
    print("\n=== Coordinate analysis ===")
    print(f"Total softmax threads: {num_softmax_threads} (warps 0-3)")
    print(f"Per thread: {coords_per_thread} coordinate pairs")
    print(
        f"Total coordinates: {num_softmax_threads} * {coords_per_thread}"
        f" = {num_softmax_threads * coords_per_thread}"
    )
    print(f"Expected: {tile_m}*{tile_n} = {tile_m * tile_n}")

    # ------------------------------------------------------------------
    # Design notes for the (not yet implemented) make_cotiled_copy step.
    # Layout values quoted here come from the original investigation notes;
    # confirm them against the output printed above.
    #
    # What make_cotiled_copy needs (per its docstring, as quoted in the
    # original notes):
    #   atom_layout_tv: (tid, vid)  -> data addr
    #   data_layout:    data coord  -> data addr
    # so the codomain of atom_layout_tv must match that of data_layout.
    #
    # What we have: the TMEM load's TV layout maps (tid, vid) to tStS
    # addresses.  Noted tStS layout: ((128,128),1,1):((65536,1),0,0), i.e.
    #   tStS_addr = m * 65536 + k
    # Because k < 128 < 65536 the map is invertible:
    #   m = tStS_addr // 65536,  k = tStS_addr % 65536
    #
    # What we want: sP addresses.  Noted sP outer layout (the parenthesisation
    # was garbled in the notes; reconstructed from the address formula):
    #   ((128,16),1,(4,2),1):((64,1),0,(16,8192),0)
    # For coordinate ((m, k0), i, (k1, k2), j):
    #   sP_addr = m * 64 + k0 + k1 * 16 + k2 * 8192
    # with k0 = k % 16, k1 = (k // 16) % 4, k2 = k // 64, hence
    #   sP_addr = m * 64 + (k % 16) + ((k // 16) % 4) * 16 + (k // 64) * 8192
    #
    # Swizzle: the formula above is pre-swizzle.  The swizzle lives in
    # p_smem_s.inner, not in p_smem_s.outer, and is applied on top of the
    # outer layout when the sP tensor is indexed.  For make_cotiled_copy the
    # "data addr" is therefore presumably the swizzled address, which means
    # data_layout would have to be the composed (outer + swizzle) layout --
    # open question, check how CUTLASS handles this.
    #
    # Required chain: (tid, vid) -> tStS addr -> (m, k) -> sP addr.
    # The difficulty is the middle step: recovering (m, k) from the address
    # by div/mod 65536 is not obviously expressible as a plain layout
    # composition.
    #
    # Candidate approaches:
    #   1. Build atom_layout_tv from scratch from the tTMEM_LOADcS
    #      coordinates (each thread owns 128 (m, k) pairs; 128 * 128 = 16384
    #      pairs in total).  The layout has to be built at Python (trace)
    #      time.
    #   2. Use make_tiled_copy_tv with
    #        thr_layout: (TileM, TileN)   -> tid
    #        val_layout: (ValueM, ValueN) -> vid
    #      The per-thread coordinate partition was noted as ((32,1),4,1,1),
    #      i.e. 32 x 4 = 128 coordinates per thread; the thread mapping is
    #      determined by the Ld32x32bOp atom.
    #   3. Fallback: keep the coordinate-indexed write, verify it, and come
    #      back to make_cotiled_copy as an optimisation.
    # ------------------------------------------------------------------
    print("\n=== Attempting make_cotiled_copy ===")
    print(f"tiled_tmem_load layout_dst_tv shape: {cute.shape(tv_layout)}")
    # Print the thread/value sub-layouts only if this layout object exposes them.
    if hasattr(tv_layout, 'a'):
        print(f" a (thread): {tv_layout.a}")
    if hasattr(tv_layout, 'b'):
        print(f" b (value): {tv_layout.b}")

    print("\n=== Checking if we can extract (m,k) from TV layout ===")
    # Idea: if the (threads, values) TV layout can be reinterpreted as
    # producing (m, k) coordinates, composing it with sP's layout would give
    # the atom_layout_tv we need.  For now, print the three layouts side by
    # side.  Untested follow-up: compose tv_layout with the inverse of
    # tStS.layout and then with the sP layout (e.g. via cute.composition).
    # The catch is deliberately broad: this is an exploratory diagnostic and
    # should report a failure rather than abort.
    try:
        sP_stage = p_smem_s.outer
        print(f"sP_stage layout: {sP_stage}")
        print(f"TV layout: {tv_layout}")
        print(f"tStS layout: {tStS.layout}")
    except Exception as e:
        print(f"Error: {e}")
# Script entry point: run the diagnostic when this file is executed directly.
if __name__ == "__main__":
    test_cotiled_copy_diag()