"""Minimal diagnostic: can we build atom_layout_tv for make_cotiled_copy?

Tests the layout composition chain:
  layout_dst_tv_tiled: (tid, vid) → tStS0_addr
  left_inverse(tStS_coalesced): tStS0_addr → (m, k) flat index
  sP_reindexed: (m, k) flat index → sP_addr

Result: atom_layout_tv = sP_reindexed ∘ left_inverse(tStS_coalesced) ∘ layout_dst_tv_tiled

Provenance: added as tests/unit/test_cotiled_diag.py by commit
67a2c3ee722463fa09c84903ba5467fbeda3a5d3
("diag: layout composition test for make_cotiled_copy SMEM-P").
"""
import math
import traceback

import torch

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
import cutlass.torch as ct


def _compose_reindex(sP_as_2d, sP_coalesced, tStS_inv):
    """Try to build the tStS_addr → sP_addr layout; return it, or None.

    The explicit two-mode sP layout is tried first; if that raises, the
    coalesced sP layout is tried as a fallback. Every outcome is printed.

    Args:
        sP_as_2d: hand-built sP layout passed as the outer layout of the
            first composition attempt.
        sP_coalesced: coalesced sP layout used for the fallback attempt.
        tStS_inv: result of ``cute.left_inverse`` on the coalesced tStS layout.

    Returns:
        The composed layout, or ``None`` when no composition produced one.
    """
    reindex = None
    # Broad catch is deliberate: this is a probe and the exception types
    # raised by the layout algebra are not known up front.
    try:
        reindex = cute.composition(sP_as_2d, tStS_inv)
        print(f"reindex layout: {reindex}")
        print(f"reindex shape: {cute.shape(reindex)}")
        return reindex
    except Exception as e:
        print(f"composition(sP, tStS_inv) FAILED: {e}")

    # Fallback: same composition with the coalesced sP layout.
    try:
        reindex = cute.composition(sP_coalesced, tStS_inv)
        print(f"reindex (coalesced) layout: {reindex}")
    except Exception as e2:
        print(f"composition(sP_coalesced, tStS_inv) ALSO FAILED: {e2}")
    return reindex


def _try_cotiled_copy(reindex, dst_tv, sP_coalesced):
    """Compose the TV layout with ``reindex`` and attempt make_cotiled_copy.

    Prints the intermediate ``atom_layout_tv`` and, on success, the tiled
    copy's TV layouts. Any exception is reported with a traceback instead of
    being propagated, so the rest of the diagnostic still runs.

    Args:
        reindex: layout returned by ``_compose_reindex`` (must not be None).
        dst_tv: ``layout_dst_tv_tiled`` of the tiled TMEM load.
        sP_coalesced: coalesced sP layout handed to ``make_cotiled_copy``.
    """
    try:
        # atom_layout_tv is meant to map (tid, vid) → tStS_addr → sP_addr.
        atom_layout_tv = cute.composition(reindex, dst_tv)
        print(f"\natom_layout_tv: {atom_layout_tv}")
        print(f"atom_layout_tv shape: {cute.shape(atom_layout_tv)}")

        r2s_atom = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            BFloat16,
            num_bits_per_copy=16,
        )
        tiled_r2s = cute.make_cotiled_copy(r2s_atom, atom_layout_tv, sP_coalesced)
        print("\nmake_cotiled_copy SUCCEEDED!")
        print(f"tiled_r2s layout_tv_tiled: {tiled_r2s.layout_tv_tiled}")
        print(f"tiled_r2s layout_dst_tv_tiled: {tiled_r2s.layout_dst_tv_tiled}")

        # Only the per-thread slice is taken here. partition_D / partition_S
        # are NOT called, so the message must not claim that they work.
        tiled_r2s.get_slice(0)
        print("get_slice(0) works (partition_D / partition_S not exercised)")
    except Exception as e:
        print(f"\nmake_cotiled_copy FAILED: {e}")
        traceback.print_exc()


def main():
    """Print every intermediate layout of the SMEM-P cotiled-copy chain.

    Builds the QK and PV tiled MMAs, the P shared-memory layout and the TMEM
    load's TV layout, then tries, step by step, to compose them into an
    ``atom_layout_tv`` suitable for ``cute.make_cotiled_copy``. Each step
    prints its result or the failure reason. Returns ``None``; returns early
    if ``left_inverse`` of the tStS layout cannot be computed.

    Allocates tensors with ``device='cuda'``, so a CUDA device is required.
    """
    head_dim = 256
    s_k = 128
    m = 128
    pv_n_tile = min(head_dim, 256)

    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, head_dim, dtype=torch.bfloat16, device='cuda')

    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    v_fmha = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    mV = ct.from_dlpack(v_fmha).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_fmha))

    a_major = cute.LayoutEnum.from_tensor(mQ).mma_major_mode()
    b_major = cute.LayoutEnum.from_tensor(mK).mma_major_mode()
    v_major = cute.LayoutEnum.from_tensor(mV).mma_major_mode()

    qk_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, b_major, Float32,
        tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM
    )
    pv_source = tcgen05.OperandSource.SMEM
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, v_major, Float32,
        tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source
    )

    qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
    qk_mma_tiler = (128, 128, qk_ik * 4)
    pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
    pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))

    p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)

    qk_thr = qk_mma.get_slice(0)
    qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_as)
    tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)

    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
    )
    tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)

    # ===========================
    # Step 1: Get TV layout
    # ===========================
    dst_tv = tiled_tmem_load.layout_dst_tv_tiled
    print(f"dst_tv shape: {dst_tv.shape}")
    print(f"dst_tv stride: {dst_tv.stride}")
    print(f"dst_tv size: {cute.size(dst_tv)}")

    # ===========================
    # Step 2: Get sP layout (coalesced, logical, no swizzle)
    # ===========================
    sP_outer = p_smem_s.outer
    sP_coalesced = cute.coalesce(sP_outer)
    print(f"\nsP outer shape: {cute.shape(sP_outer)}")
    print(f"sP outer: {sP_outer}")
    print(f"sP coalesced shape: {cute.shape(sP_coalesced)}")
    print(f"sP coalesced: {sP_coalesced}")

    # ===========================
    # Step 3: Get tStS0 layout (coalesced)
    # ===========================
    tStS_coalesced = cute.coalesce(tStS0.layout)
    print(f"\ntStS layout: {tStS0.layout}")
    print(f"tStS coalesced: {tStS_coalesced}")
    print(f"tStS coalesced shape: {cute.shape(tStS_coalesced)}")

    # ===========================
    # Step 4: Try left_inverse of tStS
    # ===========================
    # Everything downstream needs this inverse, so bail out if it fails.
    try:
        tStS_inv = cute.left_inverse(tStS_coalesced)
        print(f"\ntStS left_inverse: {tStS_inv}")
        print(f"tStS left_inverse shape: {cute.shape(tStS_inv)}")
    except Exception as e:
        print(f"\ntStS left_inverse FAILED: {e}")
        return

    # ===========================
    # Step 5: Build sP in the same (m, k) coordinate space as tStS
    # ===========================
    # Author's working assumptions (verify against the layouts printed above):
    #   tStS has layout (128, 128) with stride (65536, 1)
    #   sP has layout ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0)
    #
    # We need sP as a layout with the same domain shape as tStS, (128, 128),
    # i.e. sP_as_128x128: (m, k) → sP_addr.
    #
    # The sP layout in decomposed form:
    #   (m, k0, k1, k2) → 64*m + 1*k0 + 16*k1 + 8192*k2
    #   where k = k0 + 16*k1 + 64*k2
    #
    # This is not a simple 2D layout, because the k → (k0, k1, k2)
    # decomposition is not affine in the traditional sense. CuTe layouts do
    # support hierarchical/nested shapes, though, so sP can be written as
    #   shape=(128, (16, 4, 2)), stride=(64, (1, 16, 8192))
    # with two modes:
    #   mode 0: (128,) with stride (64,)                 — the m dimension
    #   mode 1: (16, 4, 2) with stride (1, 16, 8192)     — k, decomposed
    sP_as_2d = cute.make_layout(
        (128, (16, 4, 2)),
        stride=(64, (1, 16, 8192))
    )
    print(f"\nsP as 2D layout: {sP_as_2d}")
    print(f"sP as 2D shape: {cute.shape(sP_as_2d)}")

    # Does this match sP_coalesced?
    # sP_coalesced is coalesced from ((128,16),1,(4,2),1):((64,1),0,(16,8192),0).
    # Coalescing removes size-1 modes and merges adjacent modes with
    # compatible strides, so the result should be
    #   ((128, 16, 4, 2)):((64, 1, 16, 8192))
    # or possibly
    #   ((128, (16, 4, 2))):((64, (1, 16, 8192))).
    # Cheap sanity check: compare element counts.
    print(f"sP_as_2d size: {cute.size(sP_as_2d)} (should be 16384)")
    print(f"sP_coalesced size: {cute.size(sP_coalesced)} (should be 16384)")

    # ===========================
    # Step 6: Build the address mapping
    # ===========================
    # We have (again, per the assumptions above):
    #   tStS: (128, 128)      → addr with stride (65536, 1)
    #   sP:   (128, (16,4,2)) → addr with stride (64, (1, 16, 8192))
    #
    # Both map (m, k) → address with different strides. We want a layout that
    # maps tStS_addr → sP_addr. Since both share the (m, k) domain of the P
    # matrix, compose sP ∘ left_inverse(tStS): tStS_addr → (m, k) → sP_addr.
    #
    # Caveats:
    #   * left_inverse(tStS) might not produce a simple (m, k) layout because
    #     tStS has non-compact strides.
    #   * composition(A, B) requires B's codomain to be A's domain. The
    #     domains agree logically, (128, 128) ≡ (128, (16, 4, 2)) since
    #     16*4*2 = 128, but the shapes still need to be compatible.
    #
    # Target mapping, for each (m, k) pair:
    #   tStS_addr = 65536*m + k
    #   sP_addr   = 64*m + k%16 + 16*(k//16%4) + 8192*(k//64)
    print("\n=== Trying composition chain ===")
    # Reuse the inverse from Step 4 rather than recomputing it.
    print(f"tStS_inv: {tStS_inv}")
    reindex = _compose_reindex(sP_as_2d, sP_coalesced, tStS_inv)

    # ===========================
    # Step 7: If composition works, try make_cotiled_copy
    # ===========================
    if reindex is None:
        # Without a reindex layout there is nothing to compose; say so
        # instead of reporting a spurious make_cotiled_copy failure.
        print("\nmake_cotiled_copy SKIPPED: no reindex layout was produced in Step 6")
    else:
        _try_cotiled_copy(reindex, dst_tv, sP_coalesced)

    # ===========================
    # Alternative: Try make_tiled_copy_tv
    # ===========================
    # Notes for a possible make_tiled_copy_tv route (assumptions, unverified):
    #   * softmax uses 128 threads and each thread holds 128 values
    #   * thr_layout: (128, 1) → tid (or another layout matching the TMEM load)
    #   * val_layout: (32, 4) → vid (from the coordinate shape ((32,1),4))
    #
    # make_tiled_copy_tv needs:
    #   thr_layout: mapping from tile coordinates to thread IDs
    #   val_layout: mapping from value coordinates to value IDs
    # and both should be "compact" (produce consecutive thread/value IDs).
    #
    # The TMEM load's TV layout likely has 128 threads (4 warps × 32 threads)
    # and 128 values per thread (32×4 from the coordinate shape). If it has
    # shape ((A, B), (C, D)) with some nesting, we still need to work out
    # which dimensions are "thread" and which are "value" — hence the dump.
    print("\n=== TV layout analysis ===")
    print(f"Full TV layout: {dst_tv}")
    print(f"TV shape raw: {dst_tv.shape}")
    print(f"TV stride raw: {dst_tv.stride}")


if __name__ == '__main__':
    main()