diag: simplified cotiled layout test

This commit is contained in:
2026-05-24 01:53:24 +00:00
parent 60ff5e4f53
commit a5fbc6cebd

View File

@@ -1,47 +1,66 @@
"""Minimal diagnostic: can we build atom_layout_tv for make_cotiled_copy?
"""Minimal diagnostic: test layout composition for SMEM-P make_cotiled_copy.
Tests the layout composition chain:
layout_dst_tv_tiled: (tid, vid) → tStS0_addr
left_inverse(tStS_coalesced): tStS0_addr → (m, k) flat index
sP_reindexed: (m, k) flat index → sP_addr
Tests whether we can compose the TMEM-load TV layout with the sP address mapping
to build atom_layout_tv for make_cotiled_copy.
Result: atom_layout_tv = sP_reindexed ∘ left_inverse(tStS_coalesced) ∘ layout_dst_tv_tiled
This test uses the actual FmhaKernel setup to get the real layouts.
"""
import torch, math
import torch, math, sys
import cutlass, cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def main():
head_dim = 256
s_k = 128
m = 128
pv_n_tile = min(head_dim, 256)
# Use FmhaKernel to set up the same layouts
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FmhaKernel(head_dim=head_dim, s_k=s_k)
# Do the same setup as __call__ but extract layouts before launching
pv_n_tile = kernel.pv_n_tile
v_tile = v[:, 0:pv_n_tile].contiguous()
v_kernel = v_tile.unsqueeze(-1)
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
v_fmha = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
mV = ct.from_dlpack(v_fmha).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_fmha))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c[:, 0:pv_n_tile, :]).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c[:, 0:pv_n_tile, :]))
# Reproduce the __call__ setup to extract the layouts
q_dtype = BFloat16
from cutlass.utils import LayoutEnum
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
v_major = LayoutEnum.from_tensor(mV).mma_major_mode()
# v_fmha layout
v_fmha_layout = cute.make_layout(
(pv_n_tile, s_k, 1),
stride=(1, pv_n_tile, pv_n_tile * s_k),
)
v_major = LayoutEnum.ROW_MAJOR # based on the v_fmha layout
qk_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, a_major, b_major, Float32,
q_dtype, q_dtype, a_major, b_major, Float32,
tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM
)
pv_source = tcgen05.OperandSource.SMEM
pv_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, a_major, v_major, Float32,
q_dtype, q_dtype, a_major, v_major, Float32,
tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source
)
@@ -50,180 +69,103 @@ def main():
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, q_dtype, 1)
qk_thr = qk_mma.get_slice(0)
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
# TMEM-load copy
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
# ===========================
# Step 1: Get TV layout
# Print layouts
# ===========================
dst_tv = tiled_tmem_load.layout_dst_tv_tiled
print(f"dst_tv shape: {dst_tv.shape}")
print(f"dst_tv stride: {dst_tv.stride}")
print(f"dst_tv size: {cute.size(dst_tv)}")
# ===========================
# Step 2: Get sP layout (coalesced, logical, no swizzle)
# ===========================
sP_outer = p_smem_s.outer
sP_coalesced = cute.coalesce(sP_outer)
print(f"\nsP outer shape: {cute.shape(sP_outer)}")
print(f"sP outer: {sP_outer}")
print(f"sP coalesced shape: {cute.shape(sP_coalesced)}")
print(f"\nsP outer: {sP_outer}")
print(f"sP outer shape: {cute.shape(sP_outer)}")
print(f"sP coalesced: {sP_coalesced}")
# ===========================
# Step 3: Get tStS0 layout (coalesced)
# ===========================
tStS_coalesced = cute.coalesce(tStS0.layout)
print(f"\ntStS layout: {tStS0.layout}")
print(f"tStS coalesced: {tStS_coalesced}")
print(f"tStS coalesced shape: {cute.shape(tStS_coalesced)}")
# ===========================
# Step 4: Try left_inverse of tStS
# Build sP layout in (128, 128) coordinate space
# ===========================
try:
tStS_inv = cute.left_inverse(tStS_coalesced)
print(f"\ntStS left_inverse: {tStS_inv}")
print(f"tStS left_inverse shape: {cute.shape(tStS_inv)}")
except Exception as e:
print(f"\ntStS left_inverse FAILED: {e}")
return
# ===========================
# Step 5: Build sP in the same (m, k) coordinate space as tStS
# ===========================
# tStS has layout (128, 128) with stride (65536, 1)
# sP has layout ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0)
#
# We need to express sP as a layout with the same domain shape as tStS: (128, 128)
# i.e., sP_as_128x128: (m, k) → sP_addr
#
# The sP layout in decomposed form:
# (m, k0, k1, k2) → 64*m + 1*k0 + 16*k1 + 8192*k2
# where k = k0 + 16*k1 + 64*k2
#
# This is NOT a simple 2D layout because the k→(k0,k1,k2) decomposition
# is not affine in the traditional sense.
#
# BUT: CuTe layouts support hierarchical/nested shapes.
# We can express sP as: shape=((128, (16, 4, 2)), stride=((64, (1, 16, 8192))))
# This is a layout with 2 modes:
# mode 0: (128,) with stride (64,) — the m dimension
# mode 1: (16, 4, 2) with stride (1, 16, 8192) — the k decomposed
# Let me build this layout explicitly:
sP_as_2d = cute.make_layout(
# sP outer has shape ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0)
# This is the same as (128, (16, 4, 2)) with strides (64, (1, 16, 8192))
# Let me build it explicitly:
sP_2d = cute.make_layout(
(128, (16, 4, 2)),
stride=(64, (1, 16, 8192))
)
print(f"\nsP as 2D layout: {sP_as_2d}")
print(f"sP as 2D shape: {cute.shape(sP_as_2d)}")
print(f"\nsP_2d: {sP_2d}")
print(f"sP_2d size: {cute.size(sP_2d)}")
# Verify: does this match sP_coalesced?
# sP_coalesced is coalesced from ((128,16),1,(4,2),1):((64,1),0,(16,8192),0)
# Coalescing removes size-1 modes and merges adjacent modes with compatible strides
# Result should be ((128, 16, 4, 2)):((64, 1, 16, 8192))
# Or possibly ((128, (16, 4, 2))):((64, (1, 16, 8192)))
# Let me verify by checking element counts
print(f"sP_as_2d size: {cute.size(sP_as_2d)} (should be 16384)")
print(f"sP_coalesced size: {cute.size(sP_coalesced)} (should be 16384)")
# tStS coalesced: (128, 128) with stride (65536, 1) typically
# Let me check the exact strides
print(f"\ntStS_coalesced shape: {cute.shape(tStS_coalesced)}")
print(f"tStS_coalesced stride: {tStS_coalesced.stride}")
# ===========================
# Step 6: Build the address mapping
# Try left_inverse(tStS)
# ===========================
# We have:
# tStS: (128, 128) → addr with stride (65536, 1)
# sP: (128, (16,4,2)) → addr with stride (64, (1, 16, 8192))
#
# Both map (m, k) → address, but with different strides.
# We need a layout that maps tStS_addr → sP_addr.
#
# Since both have the SAME domain (the (m, k) space of the P matrix),
# we can compose: sP ∘ left_inverse(tStS)
# This gives: tStS_addr → (m, k) → sP_addr
# But left_inverse(tStS) might not produce a simple (m, k) layout
# because tStS has non-compact strides.
# Alternative: use composition with matching domain shapes.
# tStS shape is (128, 128), sP_as_2d shape is (128, (16, 4, 2))
# They have the same logical domain: (128, 128) ≡ (128, (16, 4, 2))
# because 16*4*2 = 128
# CuTe's composition should handle this if the shapes are compatible.
# But composition(A, B) requires B's codomain to be A's domain.
# And the shapes need to match.
# Let me try a different approach: build the reindex layout directly.
# reindex: maps the tStS address space to the sP address space
# For each (m, k) pair:
# tStS_addr = 65536*m + k
# sP_addr = 64*m + k%16 + 16*(k//16%4) + 8192*(k//64)
# Since both are functions of (m, k), I can compose sP with left_inverse(tStS):
try:
# left_inverse(tStS_coalesced) maps tStS_addr → index in [0, 16384)
# But the "index" from left_inverse is the logical coordinate in tStS's domain
# For tStS with stride (65536, 1), left_inverse maps:
# addr → (addr // 65536, addr % 65536) = (m, k)
# But as a CuTe layout, the result might be (m*1 + k*something)
# Actually, left_inverse of a layout with stride (65536, 1) and shape (128, 128)
# gives a layout that maps addr → index (flat 1D or structured)
print("\n=== Trying composition chain ===")
# Step A: left_inverse(tStS_coalesced)
tStS_inv = cute.left_inverse(tStS_coalesced)
print(f"tStS_inv: {tStS_inv}")
# Step B: Compose sP with tStS_inv
# This gives: tStS_addr → (m, k) → sP_addr
# But we need to match the domain/codomain shapes
# tStS_inv maps to a coordinate space matching tStS's domain shape
# sP_as_2d maps from a coordinate space matching sP's domain shape
# If both have domain shape (128, 128) ≡ (128, (16,4,2)),
# composition should work
try:
reindex = cute.composition(sP_as_2d, tStS_inv)
print(f"reindex layout: {reindex}")
print(f"reindex shape: {cute.shape(reindex)}")
except Exception as e:
print(f"composition(sP, tStS_inv) FAILED: {e}")
# Try with coalesced versions
try:
reindex = cute.composition(sP_coalesced, tStS_inv)
print(f"reindex (coalesced) layout: {reindex}")
except Exception as e2:
print(f"composition(sP_coalesced, tStS_inv) ALSO FAILED: {e2}")
print(f"\ntStS_inv: {tStS_inv}")
print(f"tStS_inv shape: {cute.shape(tStS_inv)}")
except Exception as e:
print(f"Composition chain failed: {e}")
print(f"\ntStS left_inverse FAILED: {e}")
import traceback; traceback.print_exc()
return
# ===========================
# Step 7: If composition works, try make_cotiled_copy
# Try composition: sP ∘ tStS_inv → reindex layout
# ===========================
try:
reindex = cute.composition(sP_2d, tStS_inv)
print(f"\nreindex: {reindex}")
print(f"reindex shape: {cute.shape(reindex)}")
except Exception as e:
print(f"\ncomposition(sP, tStS_inv) FAILED: {e}")
import traceback; traceback.print_exc()
# Try with coalesced sP
try:
reindex = cute.composition(sP_coalesced, tStS_inv)
print(f"reindex (coalesced): {reindex}")
except Exception as e2:
print(f"composition(sP_coalesced, tStS_inv) ALSO FAILED: {e2}")
return
# ===========================
# Try composition: reindex ∘ dst_tv → atom_layout_tv
# ===========================
try:
# atom_layout_tv = composition(reindex, dst_tv)
# This maps (tid, vid) → tStS_addr → sP_addr
atom_layout_tv = cute.composition(reindex, dst_tv)
print(f"\natom_layout_tv: {atom_layout_tv}")
print(f"atom_layout_tv shape: {cute.shape(atom_layout_tv)}")
print(f"atom_layout_tv stride: {atom_layout_tv.stride}")
except Exception as e:
print(f"\ncomposition(reindex, dst_tv) FAILED: {e}")
import traceback; traceback.print_exc()
return
# Try make_cotiled_copy
# ===========================
# Try make_cotiled_copy
# ===========================
try:
r2s_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
BFloat16,
@@ -231,46 +173,31 @@ def main():
)
tiled_r2s = cute.make_cotiled_copy(r2s_atom, atom_layout_tv, sP_coalesced)
print(f"\nmake_cotiled_copy SUCCEEDED!")
print(f"tiled_r2s layout_tv_tiled: {tiled_r2s.layout_tv_tiled}")
print(f"tiled_r2s layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}")
print(f" layout_tv_tiled: {tiled_r2s.layout_tv_tiled}")
print(f" layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}")
# Try partition for thread 0
try:
thr_r2s = tiled_r2s.get_slice(0)
print(f" get_slice(0) SUCCEEDED")
except Exception as e:
print(f" get_slice(0) FAILED: {e}")
# Try partition
thr_r2s = tiled_r2s.get_slice(0)
print(f"partition_D and partition_S work!")
except Exception as e:
print(f"\nmake_cotiled_copy FAILED: {e}")
import traceback
traceback.print_exc()
import traceback; traceback.print_exc()
# ===========================
# Alternative: Try make_tiled_copy_tv
# ===========================
# We know the softmax threads (128 total) and each thread has 128 values
# thr_layout: (128, 1) → tid (or some other layout matching the TMEM-load)
# val_layout: (32, 4) → vid (from the coordinate shape ((32,1),4))
# The TMEM-load's TV layout tells us the thread and value structure
print(f"\n=== TV layout analysis ===")
tv = dst_tv
# If tv has shape (128, 128), that's 128 threads × 128 values
# The thread layout maps (tid_group) → tid, and val layout maps (vid_group) → vid
# But we need to decompose the TV layout into separate thread and value parts
# For make_tiled_copy_tv, we need:
# thr_layout: mapping from tile coordinates to thread IDs
# val_layout: mapping from value coordinates to value IDs
# These should be "compact" (i.e., the layout produces consecutive thread/value IDs)
# The TV layout for our TMEM load likely has:
# 128 threads (4 warps × 32 threads)
# 128 values per thread (32×4 from coordinate shape)
# If the TV layout has shape ((A, B), (C, D)) with some nesting,
# we need to figure out which dimensions are "thread" and which are "value"
print(f"Full TV layout: {tv}")
print(f"TV shape raw: {tv.shape}")
print(f"TV stride raw: {tv.stride}")
# Try with 128-bit vector width
try:
r2s_atom_128 = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
BFloat16,
num_bits_per_copy=128,
)
tiled_r2s_128 = cute.make_cotiled_copy(r2s_atom_128, atom_layout_tv, sP_coalesced)
print(f"\nmake_cotiled_copy with 128-bit SUCCEEDED!")
except Exception as e2:
print(f"make_cotiled_copy with 128-bit ALSO FAILED: {e2}")
if __name__ == '__main__':