diag: use FmhaKernel setup for cotiled test

This commit is contained in:
2026-05-24 01:54:08 +00:00
parent 7a9881ab82
commit 34fe43a551

View File

@@ -1,15 +1,13 @@
"""Minimal diagnostic: test layout composition for SMEM-P make_cotiled_copy.
Tests whether we can compose the TMEM-load TV layout with the sP address mapping
to build atom_layout_tv for make_cotiled_copy.
This test uses the actual FmhaKernel setup to get the real layouts.
Uses FmhaKernel's __call__ path to set up all layouts, then extracts
the TV layout and sP layout at the right point.
"""
import torch, math, sys
import torch, math
import cutlass, cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass import Float32, BFloat16, Int32
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
@@ -19,20 +17,18 @@ def main():
head_dim = 256
s_k = 128
m = 128
pv_n_tile = min(head_dim, 256)
# Use FmhaKernel to set up the same layouts
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FmhaKernel(head_dim=head_dim, s_k=s_k)
# Do the same setup as __call__ but extract layouts before launching
pv_n_tile = kernel.pv_n_tile
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# Reproduce the EXACT __call__ setup to get the MMA objects
v_tile = v[:, 0:pv_n_tile].contiguous()
v_kernel = v_tile.unsqueeze(-1)
@@ -41,26 +37,31 @@ def main():
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c[:, 0:pv_n_tile, :]).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c[:, 0:pv_n_tile, :]))
# Reproduce the __call__ setup to extract the layouts
q_dtype = BFloat16
# Derive major modes exactly as FmhaKernel does
from cutlass.utils import LayoutEnum
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
# v_major: must match FmhaKernel's derivation
# In FmhaKernel, v_fmha is created with layout (pv_n_tile, s_k, 1) stride (1, pv_n_tile, ...)
# We create a temporary tensor to get the LayoutEnum
v_fmha_tensor = torch.randn(pv_n_tile, s_k, 1, dtype=torch.bfloat16, device='cuda')
mV_fmha = ct.from_dlpack(v_fmha_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_fmha_tensor))
v_major = LayoutEnum.from_tensor(mV_fmha).mma_major_mode()
# V FMHA layout (same as FmhaKernel.__call__)
v_fmha = cute.make_tensor(
mV.iterator,
cute.make_layout(
(pv_n_tile, s_k, 1),
stride=(1, pv_n_tile, pv_n_tile * s_k),
),
)
v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
c_layout = LayoutEnum.from_tensor(mC)
qk_mma = utils.sm100.make_trivial_tiled_mma(
q_dtype, q_dtype, a_major, b_major, Float32,
BFloat16, BFloat16, a_major, b_major, Float32,
tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM
)
pv_a_major = a_major # SMEM-P path
pv_source = tcgen05.OperandSource.SMEM
pv_mma = utils.sm100.make_trivial_tiled_mma(
q_dtype, q_dtype, a_major, v_major, Float32,
BFloat16, BFloat16, pv_a_major, v_major, Float32,
tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source
)
@@ -69,8 +70,10 @@ def main():
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, q_dtype, 1)
# sP layout (PV A-operand SMEM)
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
# QK C-fragment
qk_thr = qk_mma.get_slice(0)
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
@@ -86,85 +89,85 @@ def main():
# Print layouts
# ===========================
dst_tv = tiled_tmem_load.layout_dst_tv_tiled
print(f"dst_tv shape: {dst_tv.shape}")
print(f"dst_tv stride: {dst_tv.stride}")
print(f"dst_tv size: {cute.size(dst_tv)}")
print(f"1. dst_tv shape: {dst_tv.shape}")
print(f" dst_tv stride: {dst_tv.stride}")
print(f" dst_tv: {dst_tv}")
sP_outer = p_smem_s.outer
sP_coalesced = cute.coalesce(sP_outer)
print(f"\nsP outer: {sP_outer}")
print(f"sP outer shape: {cute.shape(sP_outer)}")
print(f"sP coalesced: {sP_coalesced}")
print(f"\n2. sP outer shape: {cute.shape(sP_outer)}")
print(f" sP outer: {sP_outer}")
print(f" sP coalesced: {sP_coalesced}")
tStS_coalesced = cute.coalesce(tStS0.layout)
print(f"\ntStS layout: {tStS0.layout}")
print(f"tStS coalesced: {tStS_coalesced}")
print(f"\n3. tStS layout: {tStS0.layout}")
print(f" tStS coalesced: {tStS_coalesced}")
print(f" tStS coalesced shape: {cute.shape(tStS_coalesced)}")
# ===========================
# Build sP layout in (128, 128) coordinate space
# ===========================
# sP outer has shape ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0)
# This is the same as (128, (16, 4, 2)) with strides (64, (1, 16, 8192))
# Let me build it explicitly:
# sP outer shape is ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0)
# Equivalent to (128, (16, 4, 2)) with strides (64, (1, 16, 8192))
sP_2d = cute.make_layout(
(128, (16, 4, 2)),
stride=(64, (1, 16, 8192))
)
print(f"\nsP_2d: {sP_2d}")
print(f"sP_2d size: {cute.size(sP_2d)}")
# tStS coalesced: (128, 128) with stride (65536, 1) typically
# Let me check the exact strides
print(f"\ntStS_coalesced shape: {cute.shape(tStS_coalesced)}")
print(f"tStS_coalesced stride: {tStS_coalesced.stride}")
print(f"\n4. sP_2d: {sP_2d}")
print(f" sP_2d size: {cute.size(sP_2d)}")
# ===========================
# Try left_inverse(tStS)
# Try left_inverse(tStS_coalesced)
# ===========================
print(f"\n5. Attempting left_inverse(tStS_coalesced)...")
try:
tStS_inv = cute.left_inverse(tStS_coalesced)
print(f"\ntStS_inv: {tStS_inv}")
print(f"tStS_inv shape: {cute.shape(tStS_inv)}")
print(f" tStS_inv: {tStS_inv}")
print(f" tStS_inv shape: {cute.shape(tStS_inv)}")
except Exception as e:
print(f"\ntStS left_inverse FAILED: {e}")
print(f" FAILED: {e}")
import traceback; traceback.print_exc()
return
# ===========================
# Try composition: sP ∘ tStS_inv → reindex layout
# Try composition: sP_2d ∘ tStS_inv → reindex
# ===========================
print(f"\n6. Attempting composition(sP_2d, tStS_inv)...")
reindex = None
try:
reindex = cute.composition(sP_2d, tStS_inv)
print(f"\nreindex: {reindex}")
print(f"reindex shape: {cute.shape(reindex)}")
print(f" reindex: {reindex}")
print(f" reindex shape: {cute.shape(reindex)}")
except Exception as e:
print(f"\ncomposition(sP, tStS_inv) FAILED: {e}")
import traceback; traceback.print_exc()
# Try with coalesced sP
print(f" FAILED: {e}")
# Try with sP_coalesced instead
try:
reindex = cute.composition(sP_coalesced, tStS_inv)
print(f"reindex (coalesced): {reindex}")
print(f" reindex (coalesced): {reindex}")
except Exception as e2:
print(f"composition(sP_coalesced, tStS_inv) ALSO FAILED: {e2}")
print(f" ALSO FAILED: {e2}")
import traceback; traceback.print_exc()
return
# ===========================
# Try composition: reindex ∘ dst_tv → atom_layout_tv
# ===========================
print(f"\n7. Attempting composition(reindex, dst_tv)...")
atom_layout_tv = None
try:
atom_layout_tv = cute.composition(reindex, dst_tv)
print(f"\natom_layout_tv: {atom_layout_tv}")
print(f"atom_layout_tv shape: {cute.shape(atom_layout_tv)}")
print(f"atom_layout_tv stride: {atom_layout_tv.stride}")
print(f" atom_layout_tv: {atom_layout_tv}")
print(f" atom_layout_tv shape: {cute.shape(atom_layout_tv)}")
print(f" atom_layout_tv stride: {atom_layout_tv.stride}")
except Exception as e:
print(f"\ncomposition(reindex, dst_tv) FAILED: {e}")
print(f" FAILED: {e}")
import traceback; traceback.print_exc()
return
# ===========================
# Try make_cotiled_copy
# ===========================
print(f"\n8. Attempting make_cotiled_copy...")
try:
r2s_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
@@ -172,33 +175,21 @@ def main():
num_bits_per_copy=16,
)
tiled_r2s = cute.make_cotiled_copy(r2s_atom, atom_layout_tv, sP_coalesced)
print(f"\nmake_cotiled_copy SUCCEEDED!")
print(f" layout_tv_tiled: {tiled_r2s.layout_tv_tiled}")
print(f" layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}")
print(f" make_cotiled_copy SUCCEEDED!")
print(f" layout_tv_tiled: {tiled_r2s.layout_tv_tiled}")
print(f" layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}")
# Try partition for thread 0
# Try get_slice
try:
thr_r2s = tiled_r2s.get_slice(0)
print(f" get_slice(0) SUCCEEDED")
print(f" get_slice(0) SUCCEEDED!")
except Exception as e:
print(f" get_slice(0) FAILED: {e}")
print(f" get_slice(0) FAILED: {e}")
except Exception as e:
print(f"\nmake_cotiled_copy FAILED: {e}")
print(f" FAILED: {e}")
import traceback; traceback.print_exc()
# Try with 128-bit vector width
try:
r2s_atom_128 = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
BFloat16,
num_bits_per_copy=128,
)
tiled_r2s_128 = cute.make_cotiled_copy(r2s_atom_128, atom_layout_tv, sP_coalesced)
print(f"\nmake_cotiled_copy with 128-bit SUCCEEDED!")
except Exception as e2:
print(f"make_cotiled_copy with 128-bit ALSO FAILED: {e2}")
if __name__ == '__main__':
main()