189 lines
6.9 KiB
Python
189 lines
6.9 KiB
Python
"""Minimal diagnostic: test layout composition for SMEM-P make_cotiled_copy.
|
|
|
|
Uses FmhaKernel's __call__ path to set up all layouts, then extracts
|
|
the TV layout and sP layout at the right point.
|
|
"""
|
|
import torch, math
|
|
import cutlass, cutlass.cute as cute
|
|
import cutlass.utils as utils
|
|
from cutlass.cute.nvgpu import tcgen05
|
|
from cutlass import Float32, BFloat16, Int32
|
|
import cutlass.torch as ct
|
|
import cuda.bindings.driver as cuda
|
|
from dsv4.kernels.attention.fmha import FmhaKernel
|
|
|
|
|
|
def main():
    """Rebuild FmhaKernel's QK/PV layouts by hand and probe the SMEM-P copy.

    Steps 1-4 print the input layouts:
      - the TMEM-load destination TV layout;
      - the sP layout (PV A-operand in SMEM);
      - the QK accumulator layout (tStS);
      - a hand-written 2-D view of sP.

    Steps 5-8 then try, in order:
      - left_inverse(tStS);
      - composition(sP, tStS_inv);
      - composition(reindex, dst_tv);
      - make_cotiled_copy.

    The first step that raises prints its error and traceback and ends the
    run. The last numbered line of output therefore shows where the
    composition breaks down.

    NOTE(review): CuTe DSL layout algebra is normally traced inside a
    ``@cute.jit`` function. If the calls below fail with an MLIR-context
    error, wrap this section in one. Confirm against the installed DSL.
    """
    import traceback

    from cutlass.utils import LayoutEnum

    head_dim = 256
    s_k = 128
    m = 128
    pv_n_tile = min(head_dim, 256)

    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device="cuda")

    # Constructed for parity with the real kernel. The instance itself is not
    # used below, because the __call__ setup is reproduced by hand.
    FmhaKernel(head_dim=head_dim, s_k=s_k)

    # --- Reproduce the FmhaKernel.__call__ setup to get the MMA objects ---
    v_tile = v[:, 0:pv_n_tile].contiguous()
    # Append a trailing size-1 mode so V is rank-3, like q and k.
    v_kernel = v_tile.unsqueeze(-1)

    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(
        leading_dim=ct.get_leading_dim(v_kernel)
    )

    # Derive major modes the same way FmhaKernel does.
    a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
    b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
    v_major = LayoutEnum.from_tensor(mV).mma_major_mode()
    # Expected, per an earlier note:
    #   layout (256, 128, 1) stride (1, 256, 32768) = col-major.
    # NOTE(review): confirm against the v_major printed here.
    print(f"a_major: {a_major}, b_major: {b_major}, v_major: {v_major}")

    qk_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, b_major, Float32,
        tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM,
    )
    # SMEM-P path: the P operand comes from SMEM and reuses Q's major mode.
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, v_major, Float32,
        tcgen05.CtaGroup.ONE, (128, pv_n_tile), tcgen05.OperandSource.SMEM,
    )

    qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
    qk_mma_tiler = (128, 128, qk_ik * 4)
    pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
    # K extent: 128, rounded down to a multiple of the instruction K (pv_ik).
    pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))

    # sP layout (PV A-operand SMEM).
    p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)

    # QK accumulator (C) fragment.
    qk_thr = qk_mma.get_slice(0)
    qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_as)
    tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)

    # TMEM-load copy.
    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
    )
    tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)

    # ---- 1-3. Print the input layouts ----
    dst_tv = tiled_tmem_load.layout_dst_tv_tiled
    print(f"1. dst_tv shape: {dst_tv.shape}")
    print(f"   dst_tv stride: {dst_tv.stride}")
    print(f"   dst_tv: {dst_tv}")

    sP_outer = p_smem_s.outer
    sP_coalesced = cute.coalesce(sP_outer)
    print(f"\n2. sP outer shape: {cute.shape(sP_outer)}")
    print(f"   sP outer: {sP_outer}")
    print(f"   sP coalesced: {sP_coalesced}")

    tStS_coalesced = cute.coalesce(tStS0.layout)
    print(f"\n3. tStS layout: {tStS0.layout}")
    print(f"   tStS coalesced: {tStS_coalesced}")
    print(f"   tStS coalesced shape: {cute.shape(tStS_coalesced)}")

    # ---- 4. sP in (128, 128) coordinate space ----
    # Assumes the sP outer layout is
    #   ((128,16),1,(4,2),1) : ((64,1),0,(16,8192),0)
    # which, dropping the size-1 modes, is (128,(16,4,2)) : (64,(1,16,8192)).
    # NOTE(review): this is hard-coded. Re-check it against step 2's output
    # whenever head_dim, s_k or the tiling change.
    sP_2d = cute.make_layout((128, (16, 4, 2)), stride=(64, (1, 16, 8192)))
    print(f"\n4. sP_2d: {sP_2d}")
    print(f"   sP_2d size: {cute.size(sP_2d)}")

    # ---- 5. Invert the accumulator layout ----
    print("\n5. Attempting left_inverse(tStS_coalesced)...")
    try:
        tStS_inv = cute.left_inverse(tStS_coalesced)
        print(f"   tStS_inv: {tStS_inv}")
        print(f"   tStS_inv shape: {cute.shape(tStS_inv)}")
    except Exception as e:
        print(f"   FAILED: {e}")
        traceback.print_exc()
        return

    # ---- 6. reindex = sP_2d ∘ tStS_inv ----
    print("\n6. Attempting composition(sP_2d, tStS_inv)...")
    try:
        reindex = cute.composition(sP_2d, tStS_inv)
        print(f"   reindex: {reindex}")
        print(f"   reindex shape: {cute.shape(reindex)}")
    except Exception as e:
        print(f"   FAILED: {e}")
        # Keep the primary failure's traceback visible even if the fallback
        # below succeeds.
        traceback.print_exc()
        # Fall back to the coalesced sP layout.
        try:
            reindex = cute.composition(sP_coalesced, tStS_inv)
            print(f"   reindex (coalesced): {reindex}")
        except Exception as e2:
            print(f"   ALSO FAILED: {e2}")
            traceback.print_exc()
            return

    # ---- 7. atom_layout_tv = reindex ∘ dst_tv ----
    print("\n7. Attempting composition(reindex, dst_tv)...")
    try:
        atom_layout_tv = cute.composition(reindex, dst_tv)
        print(f"   atom_layout_tv: {atom_layout_tv}")
        print(f"   atom_layout_tv shape: {cute.shape(atom_layout_tv)}")
        print(f"   atom_layout_tv stride: {atom_layout_tv.stride}")
    except Exception as e:
        print(f"   FAILED: {e}")
        traceback.print_exc()
        return

    # ---- 8. Build the cotiled copy over sP ----
    # NOTE(review): in CuTe, make_cotiled_copy takes an atom TV layout mapping
    # (thread, value) -> data address and a data layout mapping
    # coordinate -> address. Confirm that reindex ∘ dst_tv has that meaning.
    print("\n8. Attempting make_cotiled_copy...")
    try:
        r2s_atom = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            BFloat16,
            num_bits_per_copy=16,
        )
        tiled_r2s = cute.make_cotiled_copy(r2s_atom, atom_layout_tv, sP_coalesced)
        print("   make_cotiled_copy SUCCEEDED!")
        print(f"   layout_tv_tiled: {tiled_r2s.layout_tv_tiled}")
        print(f"   layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}")
    except Exception as e:
        print(f"   FAILED: {e}")
        traceback.print_exc()
        return

    # Check that a per-thread slice can be taken from the tiled copy.
    try:
        tiled_r2s.get_slice(0)
        print("   get_slice(0) SUCCEEDED!")
    except Exception as e:
        print(f"   get_slice(0) FAILED: {e}")
        traceback.print_exc()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the diagnostic only when executed as a script, not when imported.
    main()
|