# Source file: nvfp4-megamoe-kernel/tests/unit/test_cotiled_diag.py
# (code-viewer listing metadata: 189 lines, 6.9 KiB, Python)

"""Minimal diagnostic: test layout composition for SMEM-P make_cotiled_copy.
Rebuilds by hand the tiled-MMA and layout setup of FmhaKernel's __call__
path (the kernel itself is never launched), then prints the TMEM-load TV
layout and the sP layout and tries to compose them step by step.
"""
import torch, math
import cutlass, cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16, Int32
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def _report_failure(exc, label="FAILED", with_traceback=True):
    """Print one failure line for *exc* and, optionally, the active traceback.

    Must be called from inside an ``except`` handler when *with_traceback* is
    true, because ``traceback.print_exc()`` reports the exception currently
    being handled.
    """
    print(f" {label}: {exc}")
    if with_traceback:
        import traceback
        traceback.print_exc()


def main() -> int:
    """Run the layout-composition diagnostic and report each step on stdout.

    Builds the QK and PV tiled MMAs, the sP shared-memory layout and the
    TMEM-load tiled copy, prints the layouts involved, and then attempts, in
    order: ``left_inverse`` of the coalesced tStS layout, the two
    compositions that produce ``atom_layout_tv``, ``make_cotiled_copy`` and
    finally ``get_slice(0)``.

    Requires a CUDA device (the input tensors are allocated with
    ``device='cuda'``).

    Returns:
        0 when every step succeeded, 1 as soon as a step fails (after
        printing the failure), so callers and CI can detect a failed run.
    """
    head_dim = 256
    s_k = 128
    m = 128
    pv_n_tile = min(head_dim, 256)

    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')

    # The kernel object and the stream are created to mirror the kernel's own
    # setup; nothing below reads them (the kernel is never launched here).
    kernel = FmhaKernel(head_dim=head_dim, s_k=s_k)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Reproduce the EXACT __call__ setup to get the MMA objects
    v_tile = v[:, 0:pv_n_tile].contiguous()
    v_kernel = v_tile.unsqueeze(-1)
    c_tile = c[:, 0:pv_n_tile, :]
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(
        leading_dim=ct.get_leading_dim(v_kernel)
    )
    mC = ct.from_dlpack(c_tile).mark_layout_dynamic(
        leading_dim=ct.get_leading_dim(c_tile)
    )

    # Derive major modes exactly as FmhaKernel does
    from cutlass.utils import LayoutEnum
    a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
    b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
    v_major = LayoutEnum.from_tensor(mV).mma_major_mode()
    # Author's note: layout (256, 128, 1) stride (1, 256, 32768) = col-major
    print(f"a_major: {a_major}, b_major: {b_major}, v_major: {v_major}")
    # Not consumed below; kept so an unsupported output layout still surfaces.
    c_layout = LayoutEnum.from_tensor(mC)

    qk_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, b_major, Float32,
        tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM
    )
    pv_a_major = a_major  # SMEM-P path
    pv_source = tcgen05.OperandSource.SMEM
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, pv_a_major, v_major, Float32,
        tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source
    )
    qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
    qk_mma_tiler = (128, 128, qk_ik * 4)
    pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
    pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))

    # sP layout (PV A-operand SMEM)
    p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)

    # QK C-fragment
    qk_thr = qk_mma.get_slice(0)
    qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_as)
    tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)

    # TMEM-load copy
    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
    )
    tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)

    # ===========================
    # Print layouts
    # ===========================
    dst_tv = tiled_tmem_load.layout_dst_tv_tiled
    print(f"1. dst_tv shape: {dst_tv.shape}")
    print(f" dst_tv stride: {dst_tv.stride}")
    print(f" dst_tv: {dst_tv}")

    sP_outer = p_smem_s.outer
    sP_coalesced = cute.coalesce(sP_outer)
    print(f"\n2. sP outer shape: {cute.shape(sP_outer)}")
    print(f" sP outer: {sP_outer}")
    print(f" sP coalesced: {sP_coalesced}")

    tStS_coalesced = cute.coalesce(tStS0.layout)
    print(f"\n3. tStS layout: {tStS0.layout}")
    print(f" tStS coalesced: {tStS_coalesced}")
    print(f" tStS coalesced shape: {cute.shape(tStS_coalesced)}")

    # ===========================
    # Build sP layout in (128, 128) coordinate space
    # ===========================
    # sP outer shape is ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0)
    # Equivalent to (128, (16, 4, 2)) with strides (64, (1, 16, 8192))
    # NOTE(review): these numbers are hand-written for the sizes hard-coded
    # above; compare against the "sP outer" line printed in section 2.
    sP_2d = cute.make_layout(
        (128, (16, 4, 2)),
        stride=(64, (1, 16, 8192))
    )
    print(f"\n4. sP_2d: {sP_2d}")
    print(f" sP_2d size: {cute.size(sP_2d)}")

    # Every step below catches Exception broadly on purpose: this is a
    # diagnostic, and the goal is to report which step fails and with what
    # message, whatever exception type the DSL raises.

    # ===========================
    # Try left_inverse(tStS_coalesced)
    # ===========================
    print(f"\n5. Attempting left_inverse(tStS_coalesced)...")
    try:
        tStS_inv = cute.left_inverse(tStS_coalesced)
        print(f" tStS_inv: {tStS_inv}")
        print(f" tStS_inv shape: {cute.shape(tStS_inv)}")
    except Exception as e:
        _report_failure(e)
        return 1

    # ===========================
    # Try composition: sP_2d ∘ tStS_inv → reindex
    # ===========================
    print(f"\n6. Attempting composition(sP_2d, tStS_inv)...")
    try:
        reindex = cute.composition(sP_2d, tStS_inv)
        print(f" reindex: {reindex}")
        print(f" reindex shape: {cute.shape(reindex)}")
    except Exception as e:
        _report_failure(e, with_traceback=False)
        # Try with sP_coalesced instead
        try:
            reindex = cute.composition(sP_coalesced, tStS_inv)
            print(f" reindex (coalesced): {reindex}")
        except Exception as e2:
            _report_failure(e2, label="ALSO FAILED")
            return 1

    # ===========================
    # Try composition: reindex ∘ dst_tv → atom_layout_tv
    # ===========================
    print(f"\n7. Attempting composition(reindex, dst_tv)...")
    try:
        atom_layout_tv = cute.composition(reindex, dst_tv)
        print(f" atom_layout_tv: {atom_layout_tv}")
        print(f" atom_layout_tv shape: {cute.shape(atom_layout_tv)}")
        print(f" atom_layout_tv stride: {atom_layout_tv.stride}")
    except Exception as e:
        _report_failure(e)
        return 1

    # ===========================
    # Try make_cotiled_copy
    # ===========================
    print(f"\n8. Attempting make_cotiled_copy...")
    try:
        r2s_atom = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            BFloat16,
            num_bits_per_copy=16,
        )
        tiled_r2s = cute.make_cotiled_copy(r2s_atom, atom_layout_tv, sP_coalesced)
        print(" make_cotiled_copy SUCCEEDED!")
        print(f" layout_tv_tiled: {tiled_r2s.layout_tv_tiled}")
        print(f" layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}")
        # Try get_slice; only whether it raises matters, the slice is unused.
        try:
            tiled_r2s.get_slice(0)
            print(" get_slice(0) SUCCEEDED!")
        except Exception as e:
            _report_failure(e, label="get_slice(0) FAILED", with_traceback=False)
            return 1
    except Exception as e:
        _report_failure(e)
        return 1

    return 0


if __name__ == '__main__':
    # Propagate the diagnostic result as the process exit status.
    raise SystemExit(main())