# Original listing: nvfp4-megamoe-kernel/tests/unit/test_cotiled_diag.py (205 lines, 7.4 KiB, Python)

"""Minimal diagnostic: test layout composition for SMEM-P make_cotiled_copy.
Tests whether we can compose the TMEM-load TV layout with the sP address mapping
to build atom_layout_tv for make_cotiled_copy.
This test uses the actual FmhaKernel setup to get the real layouts.
"""
import torch, math, sys
import cutlass, cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def main():
    """Probe the layout algebra needed to build an SMEM-P ``make_cotiled_copy``.

    The function constructs the QK and PV tiled MMAs, the P shared-memory
    layout and the TMEM-load tiled copy, prints the layouts involved, and
    then attempts each step of the derivation in turn:

    1. ``left_inverse`` of the coalesced tStS layout,
    2. ``composition(sP, tStS_inv)`` (retried with the coalesced sP layout if
       the explicit 2-D sP layout fails),
    3. ``composition(reindex, dst_tv)`` to obtain ``atom_layout_tv``,
    4. ``make_cotiled_copy`` with a 16-bit copy atom (retried with a 128-bit
       atom on failure) and ``get_slice(0)`` on the result.

    Every step reports success or failure on stdout; failures also print a
    traceback. Nothing is launched on the device and nothing is asserted.

    Returns:
        None. The function returns early when a step fails and a later step
        would have no input to work with.
    """
    # Imported at function scope: traceback is only used on failure paths and
    # LayoutEnum is only needed by this diagnostic.
    import traceback

    from cutlass.utils import LayoutEnum

    head_dim = 256
    s_k = 128
    m = 128

    # Use FmhaKernel to set up the same layouts
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    kernel = FmhaKernel(head_dim=head_dim, s_k=s_k)

    # Do the same setup as __call__ but extract layouts before launching
    pv_n_tile = kernel.pv_n_tile
    v_tile = v[:, 0:pv_n_tile].contiguous()
    v_kernel = v_tile.unsqueeze(-1)
    c_tile = c[:, 0:pv_n_tile, :]
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    # mV and mC are not read again below; they are still converted so the
    # tensor setup here stays in step with the one this diagnostic mirrors.
    mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(
        leading_dim=ct.get_leading_dim(v_kernel)
    )
    mC = ct.from_dlpack(c_tile).mark_layout_dynamic(
        leading_dim=ct.get_leading_dim(c_tile)
    )

    # Reproduce the __call__ setup to extract the layouts
    q_dtype = BFloat16
    a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
    b_major = LayoutEnum.from_tensor(mK).mma_major_mode()

    # v_fmha layout (built for reference only; v_major below is chosen by
    # hand to match it rather than derived from it)
    v_fmha_layout = cute.make_layout(
        (pv_n_tile, s_k, 1),
        stride=(1, pv_n_tile, pv_n_tile * s_k),
    )
    v_major = LayoutEnum.ROW_MAJOR  # based on the v_fmha layout

    qk_mma = utils.sm100.make_trivial_tiled_mma(
        q_dtype, q_dtype, a_major, b_major, Float32,
        tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM
    )
    pv_source = tcgen05.OperandSource.SMEM
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        q_dtype, q_dtype, a_major, v_major, Float32,
        tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source
    )

    qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
    qk_mma_tiler = (128, 128, qk_ik * 4)
    pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
    pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))

    p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, q_dtype, 1)

    qk_thr = qk_mma.get_slice(0)
    qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_as)
    tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)

    # TMEM-load copy
    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
    )
    tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)

    # ===========================
    # Print layouts
    # ===========================
    dst_tv = tiled_tmem_load.layout_dst_tv_tiled
    print(f"dst_tv shape: {dst_tv.shape}")
    print(f"dst_tv stride: {dst_tv.stride}")
    print(f"dst_tv size: {cute.size(dst_tv)}")

    sP_outer = p_smem_s.outer
    sP_coalesced = cute.coalesce(sP_outer)
    print(f"\nsP outer: {sP_outer}")
    print(f"sP outer shape: {cute.shape(sP_outer)}")
    print(f"sP coalesced: {sP_coalesced}")

    tStS_coalesced = cute.coalesce(tStS0.layout)
    print(f"\ntStS layout: {tStS0.layout}")
    print(f"tStS coalesced: {tStS_coalesced}")

    # ===========================
    # Build sP layout in (128, 128) coordinate space
    # ===========================
    # NOTE: the shape/strides below are hard-coded from an earlier observation
    # that sP outer has shape ((128,16),1,(4,2),1) with strides
    # ((64,1),0,(16,8192),0), i.e. (128, (16, 4, 2)) with strides
    # (64, (1, 16, 8192)). Compare against the "sP outer" line printed above
    # before trusting the results that follow.
    sP_2d = cute.make_layout(
        (128, (16, 4, 2)),
        stride=(64, (1, 16, 8192))
    )
    print(f"\nsP_2d: {sP_2d}")
    print(f"sP_2d size: {cute.size(sP_2d)}")

    # tStS coalesced is expected to be (128, 128) with stride (65536, 1);
    # print the exact values so that expectation can be checked.
    print(f"\ntStS_coalesced shape: {cute.shape(tStS_coalesced)}")
    print(f"tStS_coalesced stride: {tStS_coalesced.stride}")

    # ===========================
    # Try left_inverse(tStS)
    # ===========================
    try:
        tStS_inv = cute.left_inverse(tStS_coalesced)
        print(f"\ntStS_inv: {tStS_inv}")
        print(f"tStS_inv shape: {cute.shape(tStS_inv)}")
    except Exception as e:
        print(f"\ntStS left_inverse FAILED: {e}")
        traceback.print_exc()
        return

    # ===========================
    # Try composition: sP ∘ tStS_inv → reindex layout
    # ===========================
    try:
        reindex = cute.composition(sP_2d, tStS_inv)
        print(f"\nreindex: {reindex}")
        print(f"reindex shape: {cute.shape(reindex)}")
    except Exception as e:
        print(f"\ncomposition(sP, tStS_inv) FAILED: {e}")
        traceback.print_exc()
        # Try with coalesced sP; if this works, its result is what the
        # following stages consume.
        try:
            reindex = cute.composition(sP_coalesced, tStS_inv)
            print(f"reindex (coalesced): {reindex}")
        except Exception as e2:
            print(f"composition(sP_coalesced, tStS_inv) ALSO FAILED: {e2}")
            return

    # ===========================
    # Try composition: reindex ∘ dst_tv → atom_layout_tv
    # ===========================
    try:
        atom_layout_tv = cute.composition(reindex, dst_tv)
        print(f"\natom_layout_tv: {atom_layout_tv}")
        print(f"atom_layout_tv shape: {cute.shape(atom_layout_tv)}")
        print(f"atom_layout_tv stride: {atom_layout_tv.stride}")
    except Exception as e:
        print(f"\ncomposition(reindex, dst_tv) FAILED: {e}")
        traceback.print_exc()
        return

    # ===========================
    # Try make_cotiled_copy
    # ===========================
    try:
        r2s_atom = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            BFloat16,
            num_bits_per_copy=16,
        )
        tiled_r2s = cute.make_cotiled_copy(r2s_atom, atom_layout_tv, sP_coalesced)
        print("\nmake_cotiled_copy SUCCEEDED!")
        print(f" layout_tv_tiled: {tiled_r2s.layout_tv_tiled}")
        print(f" layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}")
        # Try partition for thread 0; only success/failure matters, so the
        # returned slice is not kept.
        try:
            tiled_r2s.get_slice(0)
            print(" get_slice(0) SUCCEEDED")
        except Exception as e:
            print(f" get_slice(0) FAILED: {e}")
    except Exception as e:
        print(f"\nmake_cotiled_copy FAILED: {e}")
        traceback.print_exc()
        # Try with 128-bit vector width
        try:
            r2s_atom_128 = cute.make_copy_atom(
                cute.nvgpu.CopyUniversalOp(),
                BFloat16,
                num_bits_per_copy=128,
            )
            cute.make_cotiled_copy(r2s_atom_128, atom_layout_tv, sP_coalesced)
            print("\nmake_cotiled_copy with 128-bit SUCCEEDED!")
        except Exception as e2:
            print(f"make_cotiled_copy with 128-bit ALSO FAILED: {e2}")
# Run the diagnostic only when this file is executed as a script.
if __name__ == '__main__':
    main()