diag: layout composition test for make_cotiled_copy SMEM-P
This commit is contained in:
276
tests/unit/test_cotiled_diag.py
Normal file
276
tests/unit/test_cotiled_diag.py
Normal file
@@ -0,0 +1,276 @@
"""Minimal diagnostic: can we build atom_layout_tv for make_cotiled_copy?

Tests the layout composition chain:
    layout_dst_tv_tiled:            (tid, vid) → tStS0_addr
    left_inverse(tStS_coalesced):   tStS0_addr → (m, k) flat index
    sP_reindexed:                   (m, k) flat index → sP_addr

Result: atom_layout_tv = sP_reindexed ∘ left_inverse(tStS_coalesced) ∘ layout_dst_tv_tiled
"""
|
||||
import torch, math
|
||||
import cutlass, cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass import Float32, BFloat16
|
||||
import cutlass.torch as ct
|
||||
|
||||
|
||||
def _compose_reindex(sP_as_2d, sP_coalesced, tStS_inv):
    """Compose an sP layout with ``tStS_inv`` to get tStS_addr → sP_addr.

    Tries the explicit 2-mode sP layout first and falls back to the coalesced
    sP layout. Each failure is printed, not raised, because this is a
    diagnostic: the point is to see which form (if any) composes.

    Returns:
        The composed layout, or ``None`` when neither composition succeeded.
    """
    try:
        reindex = cute.composition(sP_as_2d, tStS_inv)
        print(f"reindex layout: {reindex}")
        print(f"reindex shape: {cute.shape(reindex)}")
        return reindex
    except Exception as e:
        print(f"composition(sP, tStS_inv) FAILED: {e}")

    # Fallback: same composition, using the coalesced sP layout instead.
    try:
        reindex = cute.composition(sP_coalesced, tStS_inv)
        print(f"reindex (coalesced) layout: {reindex}")
        return reindex
    except Exception as e2:
        print(f"composition(sP_coalesced, tStS_inv) ALSO FAILED: {e2}")
    return None


def _try_cotiled_copy(reindex, dst_tv, sP_coalesced):
    """Build atom_layout_tv = reindex ∘ dst_tv and feed it to make_cotiled_copy.

    Intended mapping: (tid, vid) → tStS_addr → sP_addr. Any failure is printed
    together with a traceback; nothing is raised to the caller.
    """
    import traceback

    try:
        atom_layout_tv = cute.composition(reindex, dst_tv)
        print(f"\natom_layout_tv: {atom_layout_tv}")
        print(f"atom_layout_tv shape: {cute.shape(atom_layout_tv)}")

        r2s_atom = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            BFloat16,
            num_bits_per_copy=16,
        )
        tiled_r2s = cute.make_cotiled_copy(r2s_atom, atom_layout_tv, sP_coalesced)
        print("\nmake_cotiled_copy SUCCEEDED!")
        print(f"tiled_r2s layout_tv_tiled: {tiled_r2s.layout_tv_tiled}")
        print(f"tiled_r2s layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}")

        # Only slicing is exercised here; no tensor is partitioned, so the
        # message reports exactly that.
        tiled_r2s.get_slice(0)
        print("get_slice(0) works (partition_D / partition_S not exercised)")
    except Exception as e:
        print(f"\nmake_cotiled_copy FAILED: {e}")
        traceback.print_exc()


def main() -> None:
    """Print each stage of the layout chain needed for make_cotiled_copy.

    Builds QK and PV tiled MMAs, the P shared-memory layout and a TMEM-load
    tiled copy, then attempts
    ``sP ∘ left_inverse(tStS_coalesced) ∘ layout_dst_tv_tiled`` and passes the
    result to ``cute.make_cotiled_copy``. Every stage prints its layout;
    failures are reported on stdout instead of being raised.

    NOTE(review): the cute.* calls below run outside any @cute.jit function;
    confirm the installed CuTe DSL supports that.
    """
    head_dim = 256
    s_k = 128
    m = 128
    pv_n_tile = min(head_dim, 256)

    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda')

    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    v_fmha = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    mV = ct.from_dlpack(v_fmha).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_fmha))

    a_major = cute.LayoutEnum.from_tensor(mQ).mma_major_mode()
    b_major = cute.LayoutEnum.from_tensor(mK).mma_major_mode()
    v_major = cute.LayoutEnum.from_tensor(mV).mma_major_mode()

    qk_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, b_major, Float32,
        tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM
    )
    pv_source = tcgen05.OperandSource.SMEM
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, v_major, Float32,
        tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source
    )

    qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
    qk_mma_tiler = (128, 128, qk_ik * 4)
    pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
    pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))

    p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)

    qk_thr = qk_mma.get_slice(0)
    qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_as)
    tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)

    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
    )
    tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)

    # ===========================
    # Step 1: Get TV layout
    # ===========================
    dst_tv = tiled_tmem_load.layout_dst_tv_tiled
    print(f"dst_tv shape: {dst_tv.shape}")
    print(f"dst_tv stride: {dst_tv.stride}")
    print(f"dst_tv size: {cute.size(dst_tv)}")

    # ===========================
    # Step 2: Get sP layout (coalesced, logical, no swizzle)
    # ===========================
    sP_outer = p_smem_s.outer
    sP_coalesced = cute.coalesce(sP_outer)
    print(f"\nsP outer shape: {cute.shape(sP_outer)}")
    print(f"sP outer: {sP_outer}")
    print(f"sP coalesced shape: {cute.shape(sP_coalesced)}")
    print(f"sP coalesced: {sP_coalesced}")

    # ===========================
    # Step 3: Get tStS0 layout (coalesced)
    # ===========================
    tStS_coalesced = cute.coalesce(tStS0.layout)
    print(f"\ntStS layout: {tStS0.layout}")
    print(f"tStS coalesced: {tStS_coalesced}")
    print(f"tStS coalesced shape: {cute.shape(tStS_coalesced)}")

    # ===========================
    # Step 4: Try left_inverse of tStS
    # ===========================
    # Everything downstream needs this inverse, so stop here if it fails.
    try:
        tStS_inv = cute.left_inverse(tStS_coalesced)
        print(f"\ntStS left_inverse: {tStS_inv}")
        print(f"tStS left_inverse shape: {cute.shape(tStS_inv)}")
    except Exception as e:
        print(f"\ntStS left_inverse FAILED: {e}")
        return

    # ===========================
    # Step 5: Build sP in the same (m, k) coordinate space as tStS
    # ===========================
    # Author's expectation (not checked by this script):
    #   tStS: (128, 128) with stride (65536, 1)
    #   sP:   ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0)
    # i.e. sP_addr = 64*m + k0 + 16*k1 + 8192*k2 with k = k0 + 16*k1 + 64*k2.
    # The k split is not expressible as a flat 2D stride, so it is written as
    # a nested mode: mode 0 is m (128 : 64), mode 1 is k ((16,4,2) : (1,16,8192)).
    # Both layouts then share the logical domain (128, 128) since 16*4*2 = 128.
    sP_as_2d = cute.make_layout(
        (128, (16, 4, 2)),
        stride=(64, (1, 16, 8192))
    )
    print(f"\nsP as 2D layout: {sP_as_2d}")
    print(f"sP as 2D shape: {cute.shape(sP_as_2d)}")

    # Sanity check: the hand-built layout and the coalesced sP layout should
    # cover the same number of elements.
    print(f"sP_as_2d size: {cute.size(sP_as_2d)} (should be 16384)")
    print(f"sP_coalesced size: {cute.size(sP_coalesced)} (should be 16384)")

    # ===========================
    # Step 6: Build the address mapping
    # ===========================
    # Goal: tStS_addr → (m, k) → sP_addr, i.e. sP ∘ left_inverse(tStS).
    # For each (m, k), per the author's notes:
    #   tStS_addr = 65536*m + k
    #   sP_addr   = 64*m + k%16 + 16*(k//16%4) + 8192*(k//64)
    # The inverse from Step 4 is reused rather than recomputed.
    print("\n=== Trying composition chain ===")
    print(f"tStS_inv: {tStS_inv}")
    reindex = _compose_reindex(sP_as_2d, sP_coalesced, tStS_inv)

    # ===========================
    # Step 7: If composition works, try make_cotiled_copy
    # ===========================
    if reindex is None:
        # Without a reindex layout there is nothing to hand to
        # make_cotiled_copy; say so instead of failing on an unbound name.
        print("\nSkipping make_cotiled_copy: no reindex layout could be composed")
    else:
        _try_cotiled_copy(reindex, dst_tv, sP_coalesced)

    # ===========================
    # Alternative: Try make_tiled_copy_tv
    # ===========================
    # make_tiled_copy_tv would need separate thread and value layouts, so dump
    # the raw TV layout to see which modes are "thread" and which are "value".
    # Author's expectation (unverified): 128 threads (4 warps × 32) and 128
    # values per thread (32×4, from the coordinate shape ((32,1),4)).
    print("\n=== TV layout analysis ===")
    print(f"Full TV layout: {dst_tv}")
    print(f"TV shape raw: {dst_tv.shape}")
    print(f"TV stride raw: {dst_tv.stride}")
|
||||
|
||||
|
||||
# Script entry point: run the diagnostic only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user