From a5fbc6cebd23c384dc4eeedea94adba36aac62e4 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 01:53:24 +0000 Subject: [PATCH] diag: simplified cotiled layout test --- tests/unit/test_cotiled_diag.py | 281 ++++++++++++-------------------- 1 file changed, 104 insertions(+), 177 deletions(-) diff --git a/tests/unit/test_cotiled_diag.py b/tests/unit/test_cotiled_diag.py index 965ac300..05bf0140 100644 --- a/tests/unit/test_cotiled_diag.py +++ b/tests/unit/test_cotiled_diag.py @@ -1,47 +1,66 @@ -"""Minimal diagnostic: can we build atom_layout_tv for make_cotiled_copy? +"""Minimal diagnostic: test layout composition for SMEM-P make_cotiled_copy. -Tests the layout composition chain: - layout_dst_tv_tiled: (tid, vid) → tStS0_addr - left_inverse(tStS_coalesced): tStS0_addr → (m, k) flat index - sP_reindexed: (m, k) flat index → sP_addr +Tests whether we can compose the TMEM-load TV layout with the sP address mapping +to build atom_layout_tv for make_cotiled_copy. -Result: atom_layout_tv = sP_reindexed ∘ left_inverse(tStS_coalesced) ∘ layout_dst_tv_tiled +This test uses the actual FmhaKernel setup to get the real layouts. """ -import torch, math +import torch, math, sys import cutlass, cutlass.cute as cute import cutlass.utils as utils from cutlass.cute.nvgpu import tcgen05 from cutlass import Float32, BFloat16 import cutlass.torch as ct +import cuda.bindings.driver as cuda +from dsv4.kernels.attention.fmha import FmhaKernel def main(): head_dim = 256 s_k = 128 m = 128 - pv_n_tile = min(head_dim, 256) + # Use FmhaKernel to set up the same layouts q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda') k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda') v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda') + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + kernel = FmhaKernel(head_dim=head_dim, s_k=s_k) + + # Do the same setup as __call__ but extract layouts before launching + pv_n_tile = kernel.pv_n_tile + + v_tile = v[:, 0:pv_n_tile].contiguous() + v_kernel = v_tile.unsqueeze(-1) mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) - v_fmha = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1) - mV = ct.from_dlpack(v_fmha).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_fmha)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c[:, 0:pv_n_tile, :]).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c[:, 0:pv_n_tile, :])) + # Reproduce the __call__ setup to extract the layouts + q_dtype = BFloat16 from cutlass.utils import LayoutEnum a_major = LayoutEnum.from_tensor(mQ).mma_major_mode() b_major = LayoutEnum.from_tensor(mK).mma_major_mode() - v_major = LayoutEnum.from_tensor(mV).mma_major_mode() + + # v_fmha layout + v_fmha_layout = cute.make_layout( + (pv_n_tile, s_k, 1), + stride=(1, pv_n_tile, pv_n_tile * s_k), + ) + v_major = LayoutEnum.ROW_MAJOR # based on the v_fmha layout qk_mma = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, a_major, b_major, Float32, + q_dtype, q_dtype, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM ) pv_source = tcgen05.OperandSource.SMEM pv_mma = utils.sm100.make_trivial_tiled_mma( - BFloat16, BFloat16, a_major, v_major, Float32, + q_dtype, q_dtype, a_major, v_major, Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source ) @@ -50,180 +69,103 @@ def main(): pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik)) - p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1) + p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, q_dtype, 1) qk_thr = qk_mma.get_slice(0) qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2]) tStS = qk_thr.make_fragment_C(qk_as) tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) + # TMEM-load copy tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 ) tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) # =========================== - # Step 1: Get TV layout + # Print layouts # =========================== dst_tv = tiled_tmem_load.layout_dst_tv_tiled print(f"dst_tv shape: {dst_tv.shape}") print(f"dst_tv stride: {dst_tv.stride}") print(f"dst_tv size: {cute.size(dst_tv)}") - # =========================== - # Step 2: Get sP layout (coalesced, logical, no swizzle) - # =========================== sP_outer = p_smem_s.outer sP_coalesced = cute.coalesce(sP_outer) - print(f"\nsP outer shape: {cute.shape(sP_outer)}") - print(f"sP outer: {sP_outer}") - print(f"sP coalesced shape: {cute.shape(sP_coalesced)}") + print(f"\nsP outer: {sP_outer}") + print(f"sP outer shape: {cute.shape(sP_outer)}") print(f"sP coalesced: {sP_coalesced}") - # =========================== - # Step 3: Get tStS0 layout (coalesced) - # =========================== tStS_coalesced = cute.coalesce(tStS0.layout) print(f"\ntStS layout: {tStS0.layout}") print(f"tStS coalesced: {tStS_coalesced}") - print(f"tStS coalesced shape: {cute.shape(tStS_coalesced)}") # =========================== - # Step 4: Try left_inverse of tStS + # Build sP layout in (128, 128) coordinate space # =========================== - try: - tStS_inv = cute.left_inverse(tStS_coalesced) - print(f"\ntStS left_inverse: {tStS_inv}") - print(f"tStS left_inverse shape: {cute.shape(tStS_inv)}") - except Exception as e: - print(f"\ntStS left_inverse FAILED: {e}") - return - - # =========================== - # Step 5: Build sP in the same (m, k) coordinate space as tStS - # =========================== - # tStS has layout (128, 128) with stride (65536, 1) - # sP has layout ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0) - # - # We need to express sP as a layout with the same domain shape as tStS: (128, 128) - # i.e., sP_as_128x128: (m, k) → sP_addr - # - # The sP layout in decomposed form: - # (m, k0, k1, k2) → 64*m + 1*k0 + 16*k1 + 8192*k2 - # where k = k0 + 16*k1 + 64*k2 - # - # This is NOT a simple 2D layout because the k→(k0,k1,k2) decomposition - # is not affine in the traditional sense. - # - # BUT: CuTe layouts support hierarchical/nested shapes. - # We can express sP as: shape=((128, (16, 4, 2)), stride=((64, (1, 16, 8192)))) - # This is a layout with 2 modes: - # mode 0: (128,) with stride (64,) — the m dimension - # mode 1: (16, 4, 2) with stride (1, 16, 8192) — the k decomposed - - # Let me build this layout explicitly: - sP_as_2d = cute.make_layout( + # sP outer has shape ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0) + # This is the same as (128, (16, 4, 2)) with strides (64, (1, 16, 8192)) + # Let me build it explicitly: + sP_2d = cute.make_layout( (128, (16, 4, 2)), stride=(64, (1, 16, 8192)) ) - print(f"\nsP as 2D layout: {sP_as_2d}") - print(f"sP as 2D shape: {cute.shape(sP_as_2d)}") + print(f"\nsP_2d: {sP_2d}") + print(f"sP_2d size: {cute.size(sP_2d)}") - # Verify: does this match sP_coalesced? - # sP_coalesced is coalesced from ((128,16),1,(4,2),1):((64,1),0,(16,8192),0) - # Coalescing removes size-1 modes and merges adjacent modes with compatible strides - # Result should be ((128, 16, 4, 2)):((64, 1, 16, 8192)) - # Or possibly ((128, (16, 4, 2))):((64, (1, 16, 8192))) - - # Let me verify by checking element counts - print(f"sP_as_2d size: {cute.size(sP_as_2d)} (should be 16384)") - print(f"sP_coalesced size: {cute.size(sP_coalesced)} (should be 16384)") + # tStS coalesced: (128, 128) with stride (65536, 1) typically + # Let me check the exact strides + print(f"\ntStS_coalesced shape: {cute.shape(tStS_coalesced)}") + print(f"tStS_coalesced stride: {tStS_coalesced.stride}") # =========================== - # Step 6: Build the address mapping + # Try left_inverse(tStS) # =========================== - # We have: - # tStS: (128, 128) → addr with stride (65536, 1) - # sP: (128, (16,4,2)) → addr with stride (64, (1, 16, 8192)) - # - # Both map (m, k) → address, but with different strides. - # We need a layout that maps tStS_addr → sP_addr. - # - # Since both have the SAME domain (the (m, k) space of the P matrix), - # we can compose: sP ∘ left_inverse(tStS) - # This gives: tStS_addr → (m, k) → sP_addr - - # But left_inverse(tStS) might not produce a simple (m, k) layout - # because tStS has non-compact strides. - - # Alternative: use composition with matching domain shapes. - # tStS shape is (128, 128), sP_as_2d shape is (128, (16, 4, 2)) - # They have the same logical domain: (128, 128) ≡ (128, (16, 4, 2)) - # because 16*4*2 = 128 - - # CuTe's composition should handle this if the shapes are compatible. - # But composition(A, B) requires B's codomain to be A's domain. - # And the shapes need to match. - - # Let me try a different approach: build the reindex layout directly. - # reindex: maps the tStS address space to the sP address space - # For each (m, k) pair: - # tStS_addr = 65536*m + k - # sP_addr = 64*m + k%16 + 16*(k//16%4) + 8192*(k//64) - - # Since both are functions of (m, k), I can compose sP with left_inverse(tStS): try: - # left_inverse(tStS_coalesced) maps tStS_addr → index in [0, 16384) - # But the "index" from left_inverse is the logical coordinate in tStS's domain - # For tStS with stride (65536, 1), left_inverse maps: - # addr → (addr // 65536, addr % 65536) = (m, k) - # But as a CuTe layout, the result might be (m*1 + k*something) - # Actually, left_inverse of a layout with stride (65536, 1) and shape (128, 128) - # gives a layout that maps addr → index (flat 1D or structured) - - print("\n=== Trying composition chain ===") - - # Step A: left_inverse(tStS_coalesced) tStS_inv = cute.left_inverse(tStS_coalesced) - print(f"tStS_inv: {tStS_inv}") - - # Step B: Compose sP with tStS_inv - # This gives: tStS_addr → (m, k) → sP_addr - # But we need to match the domain/codomain shapes - - # tStS_inv maps to a coordinate space matching tStS's domain shape - # sP_as_2d maps from a coordinate space matching sP's domain shape - # If both have domain shape (128, 128) ≡ (128, (16,4,2)), - # composition should work - - try: - reindex = cute.composition(sP_as_2d, tStS_inv) - print(f"reindex layout: {reindex}") - print(f"reindex shape: {cute.shape(reindex)}") - except Exception as e: - print(f"composition(sP, tStS_inv) FAILED: {e}") - - # Try with coalesced versions - try: - reindex = cute.composition(sP_coalesced, tStS_inv) - print(f"reindex (coalesced) layout: {reindex}") - except Exception as e2: - print(f"composition(sP_coalesced, tStS_inv) ALSO FAILED: {e2}") - + print(f"\ntStS_inv: {tStS_inv}") + print(f"tStS_inv shape: {cute.shape(tStS_inv)}") except Exception as e: - print(f"Composition chain failed: {e}") + print(f"\ntStS left_inverse FAILED: {e}") + import traceback; traceback.print_exc() + return # =========================== - # Step 7: If composition works, try make_cotiled_copy + # Try composition: sP ∘ tStS_inv → reindex layout + # =========================== + try: + reindex = cute.composition(sP_2d, tStS_inv) + print(f"\nreindex: {reindex}") + print(f"reindex shape: {cute.shape(reindex)}") + except Exception as e: + print(f"\ncomposition(sP, tStS_inv) FAILED: {e}") + import traceback; traceback.print_exc() + + # Try with coalesced sP + try: + reindex = cute.composition(sP_coalesced, tStS_inv) + print(f"reindex (coalesced): {reindex}") + except Exception as e2: + print(f"composition(sP_coalesced, tStS_inv) ALSO FAILED: {e2}") + return + + # =========================== + # Try composition: reindex ∘ dst_tv → atom_layout_tv # =========================== try: - # atom_layout_tv = composition(reindex, dst_tv) - # This maps (tid, vid) → tStS_addr → sP_addr atom_layout_tv = cute.composition(reindex, dst_tv) print(f"\natom_layout_tv: {atom_layout_tv}") print(f"atom_layout_tv shape: {cute.shape(atom_layout_tv)}") + print(f"atom_layout_tv stride: {atom_layout_tv.stride}") + except Exception as e: + print(f"\ncomposition(reindex, dst_tv) FAILED: {e}") + import traceback; traceback.print_exc() + return - # Try make_cotiled_copy + # =========================== + # Try make_cotiled_copy + # =========================== + try: r2s_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), BFloat16, @@ -231,46 +173,31 @@ def main(): ) tiled_r2s = cute.make_cotiled_copy(r2s_atom, atom_layout_tv, sP_coalesced) print(f"\nmake_cotiled_copy SUCCEEDED!") - print(f"tiled_r2s layout_tv_tiled: {tiled_r2s.layout_tv_tiled}") - print(f"tiled_r2s layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}") + print(f" layout_tv_tiled: {tiled_r2s.layout_tv_tiled}") + print(f" layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}") + + # Try partition for thread 0 + try: + thr_r2s = tiled_r2s.get_slice(0) + print(f" get_slice(0) SUCCEEDED") + except Exception as e: + print(f" get_slice(0) FAILED: {e}") - # Try partition - thr_r2s = tiled_r2s.get_slice(0) - print(f"partition_D and partition_S work!") except Exception as e: print(f"\nmake_cotiled_copy FAILED: {e}") - import traceback - traceback.print_exc() + import traceback; traceback.print_exc() - # =========================== - # Alternative: Try make_tiled_copy_tv - # =========================== - # We know the softmax threads (128 total) and each thread has 128 values - # thr_layout: (128, 1) → tid (or some other layout matching the TMEM-load) - # val_layout: (32, 4) → vid (from the coordinate shape ((32,1),4)) - - # The TMEM-load's TV layout tells us the thread and value structure - print(f"\n=== TV layout analysis ===") - tv = dst_tv - # If tv has shape (128, 128), that's 128 threads × 128 values - # The thread layout maps (tid_group) → tid, and val layout maps (vid_group) → vid - # But we need to decompose the TV layout into separate thread and value parts - - # For make_tiled_copy_tv, we need: - # thr_layout: mapping from tile coordinates to thread IDs - # val_layout: mapping from value coordinates to value IDs - # These should be "compact" (i.e., the layout produces consecutive thread/value IDs) - - # The TV layout for our TMEM load likely has: - # 128 threads (4 warps × 32 threads) - # 128 values per thread (32×4 from coordinate shape) - - # If the TV layout has shape ((A, B), (C, D)) with some nesting, - # we need to figure out which dimensions are "thread" and which are "value" - - print(f"Full TV layout: {tv}") - print(f"TV shape raw: {tv.shape}") - print(f"TV stride raw: {tv.stride}") + # Try with 128-bit vector width + try: + r2s_atom_128 = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + BFloat16, + num_bits_per_copy=128, + ) + tiled_r2s_128 = cute.make_cotiled_copy(r2s_atom_128, atom_layout_tv, sP_coalesced) + print(f"\nmake_cotiled_copy with 128-bit SUCCEEDED!") + except Exception as e2: + print(f"make_cotiled_copy with 128-bit ALSO FAILED: {e2}") if __name__ == '__main__':