From 34fe43a5519da14ef9979400fb7c015dcb833311 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 01:54:08 +0000 Subject: [PATCH] diag: use FmhaKernel setup for cotiled test --- tests/unit/test_cotiled_diag.py | 145 +++++++++++++++----------------- 1 file changed, 68 insertions(+), 77 deletions(-) diff --git a/tests/unit/test_cotiled_diag.py b/tests/unit/test_cotiled_diag.py index 6fd274e0..e58c106a 100644 --- a/tests/unit/test_cotiled_diag.py +++ b/tests/unit/test_cotiled_diag.py @@ -1,15 +1,13 @@ """Minimal diagnostic: test layout composition for SMEM-P make_cotiled_copy. -Tests whether we can compose the TMEM-load TV layout with the sP address mapping -to build atom_layout_tv for make_cotiled_copy. - -This test uses the actual FmhaKernel setup to get the real layouts. +Uses FmhaKernel's __call__ path to set up all layouts, then extracts +the TV layout and sP layout at the right point. """ -import torch, math, sys +import torch, math import cutlass, cutlass.cute as cute import cutlass.utils as utils from cutlass.cute.nvgpu import tcgen05 -from cutlass import Float32, BFloat16 +from cutlass import Float32, BFloat16, Int32 import cutlass.torch as ct import cuda.bindings.driver as cuda from dsv4.kernels.attention.fmha import FmhaKernel @@ -19,20 +17,18 @@ def main(): head_dim = 256 s_k = 128 m = 128 + pv_n_tile = min(head_dim, 256) - # Use FmhaKernel to set up the same layouts q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda') k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda') v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda') c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda') - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - kernel = FmhaKernel(head_dim=head_dim, s_k=s_k) - # Do the same setup as __call__ but extract layouts before launching - pv_n_tile = kernel.pv_n_tile + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + # Reproduce the EXACT __call__ setup to get the MMA objects v_tile = v[:, 0:pv_n_tile].contiguous() v_kernel = v_tile.unsqueeze(-1) @@ -41,26 +37,31 @@ def main(): mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) mC = ct.from_dlpack(c[:, 0:pv_n_tile, :]).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c[:, 0:pv_n_tile, :])) - # Reproduce the __call__ setup to extract the layouts - q_dtype = BFloat16 + # Derive major modes exactly as FmhaKernel does from cutlass.utils import LayoutEnum a_major = LayoutEnum.from_tensor(mQ).mma_major_mode() b_major = LayoutEnum.from_tensor(mK).mma_major_mode() - - # v_major: must match FmhaKernel's derivation - # In FmhaKernel, v_fmha is created with layout (pv_n_tile, s_k, 1) stride (1, pv_n_tile, ...) - # We create a temporary tensor to get the LayoutEnum - v_fmha_tensor = torch.randn(pv_n_tile, s_k, 1, dtype=torch.bfloat16, device='cuda') - mV_fmha = ct.from_dlpack(v_fmha_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_fmha_tensor)) - v_major = LayoutEnum.from_tensor(mV_fmha).mma_major_mode() + + # V FMHA layout (same as FmhaKernel.__call__) + v_fmha = cute.make_tensor( + mV.iterator, + cute.make_layout( + (pv_n_tile, s_k, 1), + stride=(1, pv_n_tile, pv_n_tile * s_k), + ), + ) + v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() + + c_layout = LayoutEnum.from_tensor(mC) qk_mma = utils.sm100.make_trivial_tiled_mma( - q_dtype, q_dtype, a_major, b_major, Float32, + BFloat16, BFloat16, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM ) + pv_a_major = a_major # SMEM-P path pv_source = tcgen05.OperandSource.SMEM pv_mma = utils.sm100.make_trivial_tiled_mma( - q_dtype, q_dtype, a_major, v_major, Float32, + BFloat16, BFloat16, pv_a_major, v_major, Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source ) @@ -69,8 +70,10 @@ def main(): pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik)) - p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, q_dtype, 1) + # sP layout (PV A-operand SMEM) + p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1) + # QK C-fragment qk_thr = qk_mma.get_slice(0) qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2]) tStS = qk_thr.make_fragment_C(qk_as) @@ -86,85 +89,85 @@ def main(): # Print layouts # =========================== dst_tv = tiled_tmem_load.layout_dst_tv_tiled - print(f"dst_tv shape: {dst_tv.shape}") - print(f"dst_tv stride: {dst_tv.stride}") - print(f"dst_tv size: {cute.size(dst_tv)}") + print(f"1. dst_tv shape: {dst_tv.shape}") + print(f" dst_tv stride: {dst_tv.stride}") + print(f" dst_tv: {dst_tv}") sP_outer = p_smem_s.outer sP_coalesced = cute.coalesce(sP_outer) - print(f"\nsP outer: {sP_outer}") - print(f"sP outer shape: {cute.shape(sP_outer)}") - print(f"sP coalesced: {sP_coalesced}") + print(f"\n2. sP outer shape: {cute.shape(sP_outer)}") + print(f" sP outer: {sP_outer}") + print(f" sP coalesced: {sP_coalesced}") tStS_coalesced = cute.coalesce(tStS0.layout) - print(f"\ntStS layout: {tStS0.layout}") - print(f"tStS coalesced: {tStS_coalesced}") + print(f"\n3. tStS layout: {tStS0.layout}") + print(f" tStS coalesced: {tStS_coalesced}") + print(f" tStS coalesced shape: {cute.shape(tStS_coalesced)}") # =========================== # Build sP layout in (128, 128) coordinate space # =========================== - # sP outer has shape ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0) - # This is the same as (128, (16, 4, 2)) with strides (64, (1, 16, 8192)) - # Let me build it explicitly: + # sP outer shape is ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0) + # Equivalent to (128, (16, 4, 2)) with strides (64, (1, 16, 8192)) sP_2d = cute.make_layout( (128, (16, 4, 2)), stride=(64, (1, 16, 8192)) ) - print(f"\nsP_2d: {sP_2d}") - print(f"sP_2d size: {cute.size(sP_2d)}") - - # tStS coalesced: (128, 128) with stride (65536, 1) typically - # Let me check the exact strides - print(f"\ntStS_coalesced shape: {cute.shape(tStS_coalesced)}") - print(f"tStS_coalesced stride: {tStS_coalesced.stride}") + print(f"\n4. sP_2d: {sP_2d}") + print(f" sP_2d size: {cute.size(sP_2d)}") # =========================== - # Try left_inverse(tStS) + # Try left_inverse(tStS_coalesced) # =========================== + print(f"\n5. Attempting left_inverse(tStS_coalesced)...") try: tStS_inv = cute.left_inverse(tStS_coalesced) - print(f"\ntStS_inv: {tStS_inv}") - print(f"tStS_inv shape: {cute.shape(tStS_inv)}") + print(f" tStS_inv: {tStS_inv}") + print(f" tStS_inv shape: {cute.shape(tStS_inv)}") except Exception as e: - print(f"\ntStS left_inverse FAILED: {e}") + print(f" FAILED: {e}") import traceback; traceback.print_exc() return # =========================== - # Try composition: sP ∘ tStS_inv → reindex layout + # Try composition: sP_2d ∘ tStS_inv → reindex # =========================== + print(f"\n6. Attempting composition(sP_2d, tStS_inv)...") + reindex = None try: reindex = cute.composition(sP_2d, tStS_inv) - print(f"\nreindex: {reindex}") - print(f"reindex shape: {cute.shape(reindex)}") + print(f" reindex: {reindex}") + print(f" reindex shape: {cute.shape(reindex)}") except Exception as e: - print(f"\ncomposition(sP, tStS_inv) FAILED: {e}") - import traceback; traceback.print_exc() - - # Try with coalesced sP + print(f" FAILED: {e}") + # Try with sP_coalesced instead try: reindex = cute.composition(sP_coalesced, tStS_inv) - print(f"reindex (coalesced): {reindex}") + print(f" reindex (coalesced): {reindex}") except Exception as e2: - print(f"composition(sP_coalesced, tStS_inv) ALSO FAILED: {e2}") + print(f" ALSO FAILED: {e2}") + import traceback; traceback.print_exc() return # =========================== # Try composition: reindex ∘ dst_tv → atom_layout_tv # =========================== + print(f"\n7. Attempting composition(reindex, dst_tv)...") + atom_layout_tv = None try: atom_layout_tv = cute.composition(reindex, dst_tv) - print(f"\natom_layout_tv: {atom_layout_tv}") - print(f"atom_layout_tv shape: {cute.shape(atom_layout_tv)}") - print(f"atom_layout_tv stride: {atom_layout_tv.stride}") + print(f" atom_layout_tv: {atom_layout_tv}") + print(f" atom_layout_tv shape: {cute.shape(atom_layout_tv)}") + print(f" atom_layout_tv stride: {atom_layout_tv.stride}") except Exception as e: - print(f"\ncomposition(reindex, dst_tv) FAILED: {e}") + print(f" FAILED: {e}") import traceback; traceback.print_exc() return # =========================== # Try make_cotiled_copy # =========================== + print(f"\n8. Attempting make_cotiled_copy...") try: r2s_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), @@ -172,33 +175,21 @@ def main(): num_bits_per_copy=16, ) tiled_r2s = cute.make_cotiled_copy(r2s_atom, atom_layout_tv, sP_coalesced) - print(f"\nmake_cotiled_copy SUCCEEDED!") - print(f" layout_tv_tiled: {tiled_r2s.layout_tv_tiled}") - print(f" layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}") + print(f" make_cotiled_copy SUCCEEDED!") + print(f" layout_tv_tiled: {tiled_r2s.layout_tv_tiled}") + print(f" layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}") - # Try partition for thread 0 + # Try get_slice try: thr_r2s = tiled_r2s.get_slice(0) - print(f" get_slice(0) SUCCEEDED") + print(f" get_slice(0) SUCCEEDED!") except Exception as e: - print(f" get_slice(0) FAILED: {e}") + print(f" get_slice(0) FAILED: {e}") except Exception as e: - print(f"\nmake_cotiled_copy FAILED: {e}") + print(f" FAILED: {e}") import traceback; traceback.print_exc() - # Try with 128-bit vector width - try: - r2s_atom_128 = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - BFloat16, - num_bits_per_copy=128, - ) - tiled_r2s_128 = cute.make_cotiled_copy(r2s_atom_128, atom_layout_tv, sP_coalesced) - print(f"\nmake_cotiled_copy with 128-bit SUCCEEDED!") - except Exception as e2: - print(f"make_cotiled_copy with 128-bit ALSO FAILED: {e2}") - if __name__ == '__main__': main()