diag: use FmhaKernel setup for cotiled test
This commit is contained in:
@@ -1,15 +1,13 @@
|
||||
"""Minimal diagnostic: test layout composition for SMEM-P make_cotiled_copy.
|
||||
|
||||
Tests whether we can compose the TMEM-load TV layout with the sP address mapping
|
||||
to build atom_layout_tv for make_cotiled_copy.
|
||||
|
||||
This test uses the actual FmhaKernel setup to get the real layouts.
|
||||
Uses FmhaKernel's __call__ path to set up all layouts, then extracts
|
||||
the TV layout and sP layout at the right point.
|
||||
"""
|
||||
import torch, math, sys
|
||||
import torch, math
|
||||
import cutlass, cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass import Float32, BFloat16
|
||||
from cutlass import Float32, BFloat16, Int32
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
@@ -19,20 +17,18 @@ def main():
|
||||
head_dim = 256
|
||||
s_k = 128
|
||||
m = 128
|
||||
pv_n_tile = min(head_dim, 256)
|
||||
|
||||
# Use FmhaKernel to set up the same layouts
|
||||
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
kernel = FmhaKernel(head_dim=head_dim, s_k=s_k)
|
||||
|
||||
# Do the same setup as __call__ but extract layouts before launching
|
||||
pv_n_tile = kernel.pv_n_tile
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
# Reproduce the EXACT __call__ setup to get the MMA objects
|
||||
v_tile = v[:, 0:pv_n_tile].contiguous()
|
||||
v_kernel = v_tile.unsqueeze(-1)
|
||||
|
||||
@@ -41,26 +37,31 @@ def main():
|
||||
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
|
||||
mC = ct.from_dlpack(c[:, 0:pv_n_tile, :]).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c[:, 0:pv_n_tile, :]))
|
||||
|
||||
# Reproduce the __call__ setup to extract the layouts
|
||||
q_dtype = BFloat16
|
||||
# Derive major modes exactly as FmhaKernel does
|
||||
from cutlass.utils import LayoutEnum
|
||||
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
|
||||
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
|
||||
|
||||
# v_major: must match FmhaKernel's derivation
|
||||
# In FmhaKernel, v_fmha is created with layout (pv_n_tile, s_k, 1) stride (1, pv_n_tile, ...)
|
||||
# We create a temporary tensor to get the LayoutEnum
|
||||
v_fmha_tensor = torch.randn(pv_n_tile, s_k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
mV_fmha = ct.from_dlpack(v_fmha_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_fmha_tensor))
|
||||
v_major = LayoutEnum.from_tensor(mV_fmha).mma_major_mode()
|
||||
|
||||
# V FMHA layout (same as FmhaKernel.__call__)
|
||||
v_fmha = cute.make_tensor(
|
||||
mV.iterator,
|
||||
cute.make_layout(
|
||||
(pv_n_tile, s_k, 1),
|
||||
stride=(1, pv_n_tile, pv_n_tile * s_k),
|
||||
),
|
||||
)
|
||||
v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
|
||||
|
||||
c_layout = LayoutEnum.from_tensor(mC)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
q_dtype, q_dtype, a_major, b_major, Float32,
|
||||
BFloat16, BFloat16, a_major, b_major, Float32,
|
||||
tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM
|
||||
)
|
||||
pv_a_major = a_major # SMEM-P path
|
||||
pv_source = tcgen05.OperandSource.SMEM
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
q_dtype, q_dtype, a_major, v_major, Float32,
|
||||
BFloat16, BFloat16, pv_a_major, v_major, Float32,
|
||||
tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source
|
||||
)
|
||||
|
||||
@@ -69,8 +70,10 @@ def main():
|
||||
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
|
||||
|
||||
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, q_dtype, 1)
|
||||
# sP layout (PV A-operand SMEM)
|
||||
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
|
||||
|
||||
# QK C-fragment
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
@@ -86,85 +89,85 @@ def main():
|
||||
# Print layouts
|
||||
# ===========================
|
||||
dst_tv = tiled_tmem_load.layout_dst_tv_tiled
|
||||
print(f"dst_tv shape: {dst_tv.shape}")
|
||||
print(f"dst_tv stride: {dst_tv.stride}")
|
||||
print(f"dst_tv size: {cute.size(dst_tv)}")
|
||||
print(f"1. dst_tv shape: {dst_tv.shape}")
|
||||
print(f" dst_tv stride: {dst_tv.stride}")
|
||||
print(f" dst_tv: {dst_tv}")
|
||||
|
||||
sP_outer = p_smem_s.outer
|
||||
sP_coalesced = cute.coalesce(sP_outer)
|
||||
print(f"\nsP outer: {sP_outer}")
|
||||
print(f"sP outer shape: {cute.shape(sP_outer)}")
|
||||
print(f"sP coalesced: {sP_coalesced}")
|
||||
print(f"\n2. sP outer shape: {cute.shape(sP_outer)}")
|
||||
print(f" sP outer: {sP_outer}")
|
||||
print(f" sP coalesced: {sP_coalesced}")
|
||||
|
||||
tStS_coalesced = cute.coalesce(tStS0.layout)
|
||||
print(f"\ntStS layout: {tStS0.layout}")
|
||||
print(f"tStS coalesced: {tStS_coalesced}")
|
||||
print(f"\n3. tStS layout: {tStS0.layout}")
|
||||
print(f" tStS coalesced: {tStS_coalesced}")
|
||||
print(f" tStS coalesced shape: {cute.shape(tStS_coalesced)}")
|
||||
|
||||
# ===========================
|
||||
# Build sP layout in (128, 128) coordinate space
|
||||
# ===========================
|
||||
# sP outer has shape ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0)
|
||||
# This is the same as (128, (16, 4, 2)) with strides (64, (1, 16, 8192))
|
||||
# Let me build it explicitly:
|
||||
# sP outer shape is ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0)
|
||||
# Equivalent to (128, (16, 4, 2)) with strides (64, (1, 16, 8192))
|
||||
sP_2d = cute.make_layout(
|
||||
(128, (16, 4, 2)),
|
||||
stride=(64, (1, 16, 8192))
|
||||
)
|
||||
print(f"\nsP_2d: {sP_2d}")
|
||||
print(f"sP_2d size: {cute.size(sP_2d)}")
|
||||
|
||||
# tStS coalesced: (128, 128) with stride (65536, 1) typically
|
||||
# Let me check the exact strides
|
||||
print(f"\ntStS_coalesced shape: {cute.shape(tStS_coalesced)}")
|
||||
print(f"tStS_coalesced stride: {tStS_coalesced.stride}")
|
||||
print(f"\n4. sP_2d: {sP_2d}")
|
||||
print(f" sP_2d size: {cute.size(sP_2d)}")
|
||||
|
||||
# ===========================
|
||||
# Try left_inverse(tStS)
|
||||
# Try left_inverse(tStS_coalesced)
|
||||
# ===========================
|
||||
print(f"\n5. Attempting left_inverse(tStS_coalesced)...")
|
||||
try:
|
||||
tStS_inv = cute.left_inverse(tStS_coalesced)
|
||||
print(f"\ntStS_inv: {tStS_inv}")
|
||||
print(f"tStS_inv shape: {cute.shape(tStS_inv)}")
|
||||
print(f" tStS_inv: {tStS_inv}")
|
||||
print(f" tStS_inv shape: {cute.shape(tStS_inv)}")
|
||||
except Exception as e:
|
||||
print(f"\ntStS left_inverse FAILED: {e}")
|
||||
print(f" FAILED: {e}")
|
||||
import traceback; traceback.print_exc()
|
||||
return
|
||||
|
||||
# ===========================
|
||||
# Try composition: sP ∘ tStS_inv → reindex layout
|
||||
# Try composition: sP_2d ∘ tStS_inv → reindex
|
||||
# ===========================
|
||||
print(f"\n6. Attempting composition(sP_2d, tStS_inv)...")
|
||||
reindex = None
|
||||
try:
|
||||
reindex = cute.composition(sP_2d, tStS_inv)
|
||||
print(f"\nreindex: {reindex}")
|
||||
print(f"reindex shape: {cute.shape(reindex)}")
|
||||
print(f" reindex: {reindex}")
|
||||
print(f" reindex shape: {cute.shape(reindex)}")
|
||||
except Exception as e:
|
||||
print(f"\ncomposition(sP, tStS_inv) FAILED: {e}")
|
||||
import traceback; traceback.print_exc()
|
||||
|
||||
# Try with coalesced sP
|
||||
print(f" FAILED: {e}")
|
||||
# Try with sP_coalesced instead
|
||||
try:
|
||||
reindex = cute.composition(sP_coalesced, tStS_inv)
|
||||
print(f"reindex (coalesced): {reindex}")
|
||||
print(f" reindex (coalesced): {reindex}")
|
||||
except Exception as e2:
|
||||
print(f"composition(sP_coalesced, tStS_inv) ALSO FAILED: {e2}")
|
||||
print(f" ALSO FAILED: {e2}")
|
||||
import traceback; traceback.print_exc()
|
||||
return
|
||||
|
||||
# ===========================
|
||||
# Try composition: reindex ∘ dst_tv → atom_layout_tv
|
||||
# ===========================
|
||||
print(f"\n7. Attempting composition(reindex, dst_tv)...")
|
||||
atom_layout_tv = None
|
||||
try:
|
||||
atom_layout_tv = cute.composition(reindex, dst_tv)
|
||||
print(f"\natom_layout_tv: {atom_layout_tv}")
|
||||
print(f"atom_layout_tv shape: {cute.shape(atom_layout_tv)}")
|
||||
print(f"atom_layout_tv stride: {atom_layout_tv.stride}")
|
||||
print(f" atom_layout_tv: {atom_layout_tv}")
|
||||
print(f" atom_layout_tv shape: {cute.shape(atom_layout_tv)}")
|
||||
print(f" atom_layout_tv stride: {atom_layout_tv.stride}")
|
||||
except Exception as e:
|
||||
print(f"\ncomposition(reindex, dst_tv) FAILED: {e}")
|
||||
print(f" FAILED: {e}")
|
||||
import traceback; traceback.print_exc()
|
||||
return
|
||||
|
||||
# ===========================
|
||||
# Try make_cotiled_copy
|
||||
# ===========================
|
||||
print(f"\n8. Attempting make_cotiled_copy...")
|
||||
try:
|
||||
r2s_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
@@ -172,33 +175,21 @@ def main():
|
||||
num_bits_per_copy=16,
|
||||
)
|
||||
tiled_r2s = cute.make_cotiled_copy(r2s_atom, atom_layout_tv, sP_coalesced)
|
||||
print(f"\nmake_cotiled_copy SUCCEEDED!")
|
||||
print(f" layout_tv_tiled: {tiled_r2s.layout_tv_tiled}")
|
||||
print(f" layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}")
|
||||
print(f" make_cotiled_copy SUCCEEDED!")
|
||||
print(f" layout_tv_tiled: {tiled_r2s.layout_tv_tiled}")
|
||||
print(f" layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}")
|
||||
|
||||
# Try partition for thread 0
|
||||
# Try get_slice
|
||||
try:
|
||||
thr_r2s = tiled_r2s.get_slice(0)
|
||||
print(f" get_slice(0) SUCCEEDED")
|
||||
print(f" get_slice(0) SUCCEEDED!")
|
||||
except Exception as e:
|
||||
print(f" get_slice(0) FAILED: {e}")
|
||||
print(f" get_slice(0) FAILED: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\nmake_cotiled_copy FAILED: {e}")
|
||||
print(f" FAILED: {e}")
|
||||
import traceback; traceback.print_exc()
|
||||
|
||||
# Try with 128-bit vector width
|
||||
try:
|
||||
r2s_atom_128 = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
BFloat16,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
tiled_r2s_128 = cute.make_cotiled_copy(r2s_atom_128, atom_layout_tv, sP_coalesced)
|
||||
print(f"\nmake_cotiled_copy with 128-bit SUCCEEDED!")
|
||||
except Exception as e2:
|
||||
print(f"make_cotiled_copy with 128-bit ALSO FAILED: {e2}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user