diag: layout composition test for make_cotiled_copy SMEM-P

This commit is contained in:
2026-05-24 01:48:42 +00:00
parent 699c646497
commit 67a2c3ee72

View File

@@ -0,0 +1,276 @@
"""Minimal diagnostic: can we build atom_layout_tv for make_cotiled_copy?
Tests the layout composition chain:
layout_dst_tv_tiled: (tid, vid) → tStS0_addr
left_inverse(tStS_coalesced): tStS0_addr → (m, k) flat index
sP_reindexed: (m, k) flat index → sP_addr
Result: atom_layout_tv = sP_reindexed ∘ left_inverse(tStS_coalesced) ∘ layout_dst_tv_tiled
"""
import torch, math
import cutlass, cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
import cutlass.torch as ct
def main():
head_dim = 256
s_k = 128
m = 128
pv_n_tile = min(head_dim, 256)
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
v_fmha = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
mV = ct.from_dlpack(v_fmha).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_fmha))
a_major = cute.LayoutEnum.from_tensor(mQ).mma_major_mode()
b_major = cute.LayoutEnum.from_tensor(mK).mma_major_mode()
v_major = cute.LayoutEnum.from_tensor(mV).mma_major_mode()
qk_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, a_major, b_major, Float32,
tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM
)
pv_source = tcgen05.OperandSource.SMEM
pv_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, a_major, v_major, Float32,
tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source
)
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
qk_thr = qk_mma.get_slice(0)
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
# ===========================
# Step 1: Get TV layout
# ===========================
dst_tv = tiled_tmem_load.layout_dst_tv_tiled
print(f"dst_tv shape: {dst_tv.shape}")
print(f"dst_tv stride: {dst_tv.stride}")
print(f"dst_tv size: {cute.size(dst_tv)}")
# ===========================
# Step 2: Get sP layout (coalesced, logical, no swizzle)
# ===========================
sP_outer = p_smem_s.outer
sP_coalesced = cute.coalesce(sP_outer)
print(f"\nsP outer shape: {cute.shape(sP_outer)}")
print(f"sP outer: {sP_outer}")
print(f"sP coalesced shape: {cute.shape(sP_coalesced)}")
print(f"sP coalesced: {sP_coalesced}")
# ===========================
# Step 3: Get tStS0 layout (coalesced)
# ===========================
tStS_coalesced = cute.coalesce(tStS0.layout)
print(f"\ntStS layout: {tStS0.layout}")
print(f"tStS coalesced: {tStS_coalesced}")
print(f"tStS coalesced shape: {cute.shape(tStS_coalesced)}")
# ===========================
# Step 4: Try left_inverse of tStS
# ===========================
try:
tStS_inv = cute.left_inverse(tStS_coalesced)
print(f"\ntStS left_inverse: {tStS_inv}")
print(f"tStS left_inverse shape: {cute.shape(tStS_inv)}")
except Exception as e:
print(f"\ntStS left_inverse FAILED: {e}")
return
# ===========================
# Step 5: Build sP in the same (m, k) coordinate space as tStS
# ===========================
# tStS has layout (128, 128) with stride (65536, 1)
# sP has layout ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0)
#
# We need to express sP as a layout with the same domain shape as tStS: (128, 128)
# i.e., sP_as_128x128: (m, k) → sP_addr
#
# The sP layout in decomposed form:
# (m, k0, k1, k2) → 64*m + 1*k0 + 16*k1 + 8192*k2
# where k = k0 + 16*k1 + 64*k2
#
# This is NOT a simple 2D layout because the k→(k0,k1,k2) decomposition
# is not affine in the traditional sense.
#
# BUT: CuTe layouts support hierarchical/nested shapes.
# We can express sP as: shape=((128, (16, 4, 2)), stride=((64, (1, 16, 8192))))
# This is a layout with 2 modes:
# mode 0: (128,) with stride (64,) — the m dimension
# mode 1: (16, 4, 2) with stride (1, 16, 8192) — the k decomposed
# Let me build this layout explicitly:
sP_as_2d = cute.make_layout(
(128, (16, 4, 2)),
stride=(64, (1, 16, 8192))
)
print(f"\nsP as 2D layout: {sP_as_2d}")
print(f"sP as 2D shape: {cute.shape(sP_as_2d)}")
# Verify: does this match sP_coalesced?
# sP_coalesced is coalesced from ((128,16),1,(4,2),1):((64,1),0,(16,8192),0)
# Coalescing removes size-1 modes and merges adjacent modes with compatible strides
# Result should be ((128, 16, 4, 2)):((64, 1, 16, 8192))
# Or possibly ((128, (16, 4, 2))):((64, (1, 16, 8192)))
# Let me verify by checking element counts
print(f"sP_as_2d size: {cute.size(sP_as_2d)} (should be 16384)")
print(f"sP_coalesced size: {cute.size(sP_coalesced)} (should be 16384)")
# ===========================
# Step 6: Build the address mapping
# ===========================
# We have:
# tStS: (128, 128) → addr with stride (65536, 1)
# sP: (128, (16,4,2)) → addr with stride (64, (1, 16, 8192))
#
# Both map (m, k) → address, but with different strides.
# We need a layout that maps tStS_addr → sP_addr.
#
# Since both have the SAME domain (the (m, k) space of the P matrix),
# we can compose: sP ∘ left_inverse(tStS)
# This gives: tStS_addr → (m, k) → sP_addr
# But left_inverse(tStS) might not produce a simple (m, k) layout
# because tStS has non-compact strides.
# Alternative: use composition with matching domain shapes.
# tStS shape is (128, 128), sP_as_2d shape is (128, (16, 4, 2))
# They have the same logical domain: (128, 128) ≡ (128, (16, 4, 2))
# because 16*4*2 = 128
# CuTe's composition should handle this if the shapes are compatible.
# But composition(A, B) requires B's codomain to be A's domain.
# And the shapes need to match.
# Let me try a different approach: build the reindex layout directly.
# reindex: maps the tStS address space to the sP address space
# For each (m, k) pair:
# tStS_addr = 65536*m + k
# sP_addr = 64*m + k%16 + 16*(k//16%4) + 8192*(k//64)
# Since both are functions of (m, k), I can compose sP with left_inverse(tStS):
try:
# left_inverse(tStS_coalesced) maps tStS_addr → index in [0, 16384)
# But the "index" from left_inverse is the logical coordinate in tStS's domain
# For tStS with stride (65536, 1), left_inverse maps:
# addr → (addr // 65536, addr % 65536) = (m, k)
# But as a CuTe layout, the result might be (m*1 + k*something)
# Actually, left_inverse of a layout with stride (65536, 1) and shape (128, 128)
# gives a layout that maps addr → index (flat 1D or structured)
print("\n=== Trying composition chain ===")
# Step A: left_inverse(tStS_coalesced)
tStS_inv = cute.left_inverse(tStS_coalesced)
print(f"tStS_inv: {tStS_inv}")
# Step B: Compose sP with tStS_inv
# This gives: tStS_addr → (m, k) → sP_addr
# But we need to match the domain/codomain shapes
# tStS_inv maps to a coordinate space matching tStS's domain shape
# sP_as_2d maps from a coordinate space matching sP's domain shape
# If both have domain shape (128, 128) ≡ (128, (16,4,2)),
# composition should work
try:
reindex = cute.composition(sP_as_2d, tStS_inv)
print(f"reindex layout: {reindex}")
print(f"reindex shape: {cute.shape(reindex)}")
except Exception as e:
print(f"composition(sP, tStS_inv) FAILED: {e}")
# Try with coalesced versions
try:
reindex = cute.composition(sP_coalesced, tStS_inv)
print(f"reindex (coalesced) layout: {reindex}")
except Exception as e2:
print(f"composition(sP_coalesced, tStS_inv) ALSO FAILED: {e2}")
except Exception as e:
print(f"Composition chain failed: {e}")
# ===========================
# Step 7: If composition works, try make_cotiled_copy
# ===========================
try:
# atom_layout_tv = composition(reindex, dst_tv)
# This maps (tid, vid) → tStS_addr → sP_addr
atom_layout_tv = cute.composition(reindex, dst_tv)
print(f"\natom_layout_tv: {atom_layout_tv}")
print(f"atom_layout_tv shape: {cute.shape(atom_layout_tv)}")
# Try make_cotiled_copy
r2s_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
BFloat16,
num_bits_per_copy=16,
)
tiled_r2s = cute.make_cotiled_copy(r2s_atom, atom_layout_tv, sP_coalesced)
print(f"\nmake_cotiled_copy SUCCEEDED!")
print(f"tiled_r2s layout_tv_tiled: {tiled_r2s.layout_tv_tiled}")
print(f"tiled_r2s layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}")
# Try partition
thr_r2s = tiled_r2s.get_slice(0)
print(f"partition_D and partition_S work!")
except Exception as e:
print(f"\nmake_cotiled_copy FAILED: {e}")
import traceback
traceback.print_exc()
# ===========================
# Alternative: Try make_tiled_copy_tv
# ===========================
# We know the softmax threads (128 total) and each thread has 128 values
# thr_layout: (128, 1) → tid (or some other layout matching the TMEM-load)
# val_layout: (32, 4) → vid (from the coordinate shape ((32,1),4))
# The TMEM-load's TV layout tells us the thread and value structure
print(f"\n=== TV layout analysis ===")
tv = dst_tv
# If tv has shape (128, 128), that's 128 threads × 128 values
# The thread layout maps (tid_group) → tid, and val layout maps (vid_group) → vid
# But we need to decompose the TV layout into separate thread and value parts
# For make_tiled_copy_tv, we need:
# thr_layout: mapping from tile coordinates to thread IDs
# val_layout: mapping from value coordinates to value IDs
# These should be "compact" (i.e., the layout produces consecutive thread/value IDs)
# The TV layout for our TMEM load likely has:
# 128 threads (4 warps × 32 threads)
# 128 values per thread (32×4 from coordinate shape)
# If the TV layout has shape ((A, B), (C, D)) with some nesting,
# we need to figure out which dimensions are "thread" and which are "value"
print(f"Full TV layout: {tv}")
print(f"TV shape raw: {tv.shape}")
print(f"TV stride raw: {tv.stride}")
if __name__ == '__main__':
main()