diff --git a/tests/unit/test_d1_3_cotiled.py b/tests/unit/test_d1_3_cotiled.py new file mode 100644 index 00000000..56427544 --- /dev/null +++ b/tests/unit/test_d1_3_cotiled.py @@ -0,0 +1,297 @@ +""" +D1.3 SMEM-P: Diagnostic for make_cotiled_copy approach. + +Goal: Build a custom R→S tiled copy that maps softmax thread registers +(TMEM-load ownership) to sP (PV A-operand SMEM with swizzle). + +Steps: +1. Print tiled_tmem_load TV layout shapes +2. Print tTMEM_LOADcS coordinate partition +3. Build atom_layout_tv: (tid, vid) -> sP address +4. Create make_cotiled_copy and print partition shapes +5. Test: write P to sP via cotiled copy, read back via PV MMA, verify +""" +import torch, math +import cutlass, cutlass.cute as cute, cutlass.utils as utils +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cutlass.torch as ct +import cuda.bindings.driver as cuda + + +def test_cotiled_copy_diag(): + """Print TV layout shapes from TMEM load and sP to understand the mapping.""" + print("=== make_cotiled_copy Diagnostic ===\n") + + head_dim = 64 # Start with proven hd=64 (TMEM-P works at cos 0.973) + s_k = 128 + pv_n_tile = min(head_dim, 256) + qk_mma_tiler = (128, 128, 128 * 4) + pv_mma_tiler = (128, pv_n_tile, 128) + + # Build MMA objects + a_major = LayoutEnum.ROW_MAJOR + b_major = LayoutEnum.ROW_MAJOR + v_major = LayoutEnum.ROW_MAJOR + + qk_mma = utils.sm100.make_trivial_tiled_mma( + BFloat16, BFloat16, a_major, b_major, Float32, + tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM + ) + pv_mma = utils.sm100.make_trivial_tiled_mma( + BFloat16, BFloat16, a_major, v_major, Float32, + tcgen05.CtaGroup.ONE, (128, pv_n_tile), tcgen05.OperandSource.SMEM + ) + + # Build SMEM layouts + p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1) + + # Build QK C-fragment and TMEM load partition + qk_thr = qk_mma.get_slice(0) + qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_as) + + # TMEM load atoms + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32 + ) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS) + + # Print the TV layout of the TMEM load + print(f"tiled_tmem_load shape: {cute.shape(tiled_tmem_load)}") + tv_layout = tiled_tmem_load.layout_dst_tv_tiled + print(f"TV layout: {tv_layout}") + print(f"TV layout shape: {cute.shape(tv_layout)}") + + # Print per-thread coordinate partition + sfw_idx = 0 # first softmax thread + thr_load = tiled_tmem_load.get_slice(sfw_idx) + tTMEM_LOADtS = thr_load.partition_S(tStS) + print(f"\ntTMEM_LOADtS shape (thread 0): {cute.shape(tTMEM_LOADtS)}") + print(f"tTMEM_LOADtS layout: {tTMEM_LOADtS.layout}") + + cS = cute.make_identity_tensor((qk_mma_tiler[0], qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tTMEM_LOADcS = thr_load.partition_D(tScS) + print(f"tTMEM_LOADcS shape (thread 0): {cute.shape(tTMEM_LOADcS)}") + print(f"tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}") + + # Print sP layout + sP_shape = p_smem_s.outer + print(f"\nsP outer shape: {cute.shape(sP_shape)}") + print(f"sP outer layout: {sP_shape}") + print(f"sP inner (swizzle): {p_smem_s.inner}") + + # Print what each softmax thread's coordinates look like + # For 128 threads, each with ((32,1),4,1,1) = 128 elements + print(f"\n=== Coordinate analysis ===") + print(f"Total softmax threads: 128 (warps 0-3)") + print(f"Per thread: {cute.size(tTMEM_LOADcS)} coordinate pairs") + print(f"Total coordinates: 128 * {cute.size(tTMEM_LOADcS)} = {128 * cute.size(tTMEM_LOADcS)}") + print(f"Expected: 128*128 = {128*128}") + + # The key question: can we build an atom_layout_tv that maps + # (tid, vid) -> sP address? + # + # For make_cotiled_copy: + # atom_layout_tv: (tid, vid) -> data address in sP's codomain + # data_layout: data coord -> data address (this is sP's layout) + # + # The TMEM load's TV layout maps (tid, vid) -> tStS0 coordinates. + # tStS0 is a TMEM tensor with layout ((128,128),1,1):((65536,1),0,0) + # So (tid, vid) -> flat TMEM address. + # But we need (tid, vid) -> flat sP address. + # + # The connection: tStS0 stores S with logical (m, k) = (128, 128). + # sP stores P with the same logical (m, k) but in SMEM layout. + # So the mapping is: + # (tid, vid) -> tStS0 address -> (m, k) via inverse of tStS0.layout + # (m, k) -> sP address via sP.layout + # + # But computing the inverse of tStS0 layout at Python time is the challenge. + # Let's try a different approach: build the atom_layout_tv directly. + + # The TMEM load has 128 threads and 128 values per thread. + # Total values = 128 * 128 = 16384 = 128 * 128 ✓ + # + # We need: atom_layout_tv such that for thread tid and value vid, + # the output is the flat address in sP's layout. + # + # The TMEM load's layout_dst_tv_tiled already maps (tid, vid) -> tStS0 flat addr. + # But tStS0 flat addr = m * 65536 + k (from layout ((128,128):((65536,1))) + # So we can extract (m, k) from the address: m = addr // 65536, k = addr % 65536 + # Wait, that's not right. The layout is ((128,128),1,1):((65536,1),0,0) + # So the address for coordinate (m, k) is m * 65536 + k * 1 + # Meaning: addr = m * 65536 + k + # So: m = addr // 65536, k = addr % 1... no, stride of k is 1. + # Actually: addr = m * 65536 + k. So m = addr // 65536, k = addr % 65536 + # But k ranges from 0 to 127, so k = addr % 65536 (which should be < 128). + # And m ranges from 0 to 127, so m = addr // 65536 (which should be < 128). + # Wait, that doesn't work because addr could be huge. + # Let me reconsider. + + # tStS layout: ((128,128),1,1):((65536,1),0,0) + # This is a 3-mode tensor. For coordinate ((m, k), i, j): + # addr = m * 65536 + k * 1 + i * 0 + j * 0 + # So effectively: addr = m * 65536 + k + # + # sP layout (from p_smem_s.outer): ((128,16),1,(4,2),1):(((64,1),0,((16,8192),0) + # For coordinate ((m, k0), i, (k1, k2), j): + # addr = m * 64 + k0 * 1 + i * 0 + k1 * 16 + k2 * 8192 + j * 0 + # addr = m * 64 + k0 + k1 * 16 + k2 * 8192 + # + # Given (m, k) from TMEM load coords: + # k0 = k % 16, k1 = (k // 16) % 4, k2 = k // 64 + # sP_addr = m * 64 + (k % 16) + ((k // 16) % 4) * 16 + (k // 64) * 8192 + # + # But we need to account for the swizzle! The swizzle is applied by + # CuTe automatically when you index into sP. So the sP.layout already + # includes the swizzle in the address computation. + # + # Actually, the swizzle is in p_smem_s.inner, not in p_smem_s.outer. + # The outer layout is the logical layout (no swizzle). + # When we allocate sP with swizzle=p_smem_s.inner, the indexing + # automatically applies the swizzle XOR. + # + # So for make_cotiled_copy, data_layout should be the COMPOSED layout + # (outer with swizzle applied). Let me check how CUTLASS handles this. + # + # Actually, in CuTe, when you do cute.make_tensor(ptr, layout, swizzle), + # the swizzle is applied during tensor creation. The layout is the + # pre-swizzle layout, and the swizzle XOR is applied on top. + # + # For make_cotiled_copy, we need to think about what "address" means. + # The atom_layout_tv maps (tid, vid) to data addresses. + # If data_layout includes the swizzle, then the addresses are post-swizzle. + # If not, they're pre-swizzle. + # + # The safest approach: use the 2D sP without the stage dimension, + # and let CuTe handle swizzle via the tensor's layout+swizzle. + # + # Actually, let me look at what make_cotiled_copy expects. + # The docstring says: "atom_layout_tv: (tid, vid) -> data addr" + # and "data_layout: data coord -> data addr" + # So the atom_layout_tv's codomain should match data_layout's codomain. + # + # For sP with swizzle, the "data addr" is the swizzled address. + # So we need to compose: (tid, vid) -> (m, k) -> swizzled sP addr. + # + # This is getting complex. Let me try the simpler approach first: + # use the current coordinate-indexed write but verify it works. + # Then come back and optimize with make_cotiled_copy. + + print("\n=== Attempting make_cotiled_copy ===") + + # Step 1: Build atom_layout_tv from TMEM load's TV layout. + # The TMEM load's layout_dst_tv_tiled maps (tid, vid) -> tStS0 flat address. + # We need to transform this to map (tid, vid) -> sP flat address. + # + # The transformation is: + # tStS0_addr = m * 65536 + k (from tStS0 layout) + # sP_addr = m * 64 + (k % 16) + ((k // 16) % 4) * 16 + (k // 64) * 8192 + # + # This is NOT a simple layout composition because the (m, k) decomposition + # from tStS0_addr is not trivial (stride 65536 for m). + # + # Alternative: Use make_cotiled_copy with the TMEM load's TV layout + # but change the data_layout to something that maps tStS0 addresses to + # the same codomain as sP addresses. + # + # Actually, the cleanest approach is to build atom_layout_tv from scratch + # using the coordinate information from tTMEM_LOADcS. + # + # Each softmax thread owns 128 (m, k) pairs. + # We can enumerate all 128 * 128 = 16384 (tid, (m, k)) pairs + # and compute the sP address for each. + + # But in CuTeDSL, we can't iterate and build layouts at Python time + # inside @cute.jit. We need to build the layout at Python (trace) time. + + # Let me try a different approach: use make_tiled_copy_tv. + # thr_layout: maps (TileM, TileN) -> tid + # val_layout: maps (ValueM, ValueN) -> vid + # + # The softmax threads are indexed 0..127 (4 warps × 32 threads). + # Each thread owns a (32, 4) sub-tile of the P matrix (from tTMEM_LOADcS shape ((32,1),4)). + # So thr_layout should map (32, 4) tiles to 128 threads. + # And val_layout should map (32, 1) values per tile position. + # + # Wait, the TMEM load's coordinate partition is ((32,1),4,1,1). + # This means each thread has 32 × 4 = 128 coordinate pairs. + # The first mode (32,1) is the "row" within a fragment. + # The second mode 4 is the "fragment" index. + # + # For make_tiled_copy_tv: + # - thr_layout: how threads tile the (M, K) = (128, 128) P matrix + # - val_layout: how values are arranged within a thread's tile + # + # From the TMEM load partition, each thread owns: + # - 32 M-values (not necessarily contiguous) + # - 4 K-fragments + # The thread layout is determined by the Ld32x32bOp atom's thread mapping. + + # Let me just print the actual TV layout to understand it. + print(f"tiled_tmem_load layout_dst_tv shape: {cute.shape(tv_layout)}") + if hasattr(tv_layout, 'a'): + print(f" a (thread): {tv_layout.a}") + if hasattr(tv_layout, 'b'): + print(f" b (value): {tv_layout.b}") + + # Print sP composed layout (with swizzle) + # The key insight from the CUTLASS LLM: we need atom_layout_tv + # such that atom_layout_tv(tid, vid) gives a coordinate in sP's codomain. + # + # But actually, for make_cotiled_copy, the atom_layout_tv maps to + # "data addr" which is the same codomain as data_layout. + # data_layout maps data coordinates to addresses. + # So atom_layout_tv(tid, vid) should produce an address. + # + # The TV layout from tiled_tmem_load maps (tid, vid) to tStS0 addresses. + # If we could remap tStS0 addresses to sP addresses, we'd have it. + # + # tStS0 addresses: m * 65536 + k (m in [0,128), k in [0,128)) + # sP addresses: m * 64 + (k % 16) + ((k // 16) % 4) * 16 + (k // 64) * 8192 + # + swizzle XOR + # + # The mapping is: tStS0_addr -> (m, k) -> sP_coord -> sP_addr + # tStS0_addr = m * 65536 + k + # m = tStS0_addr // 65536 + # k = tStS0_addr % 65536 (but should be < 128) + # Actually since k < 128 and stride is 1, k = tStS0_addr % 65536 is correct + # and since 128 < 65536, this gives k directly. + # And m = tStS0_addr // 65536 (since k < 65536, integer division works). + + # The problem: computing m = addr // 65536 and k = addr % 65536 + # from a Layout object is not straightforward. Layouts are affine maps. + # Division/modulo by 65536 is not affine in general. + # + # BUT: the TMEM load's TV layout already encodes the (tid, vid) -> addr map + # as a Layout. We need to transform this Layout to produce sP addresses + # instead of tStS0 addresses. + # + # Let me try yet another approach: just print what we have and think. + + print("\n=== Checking if we can extract (m,k) from TV layout ===") + # The TV layout has shape (128_threads, 128_values) mapping to addresses. + # If we can reshape this to (128, 128) and interpret the output as (m, k) + # coordinates, then compose with sP's layout, we get what we need. + + # Actually, let me try the simplest possible thing: + # Print the existing TV layout and see if it's compatible with sP + # in any way. + try: + sP_stage = p_smem_s.outer + # Try to see what the layout composition would look like + print(f"sP_stage layout: {sP_stage}") + print(f"TV layout: {tv_layout}") + # Can we compose tv_layout with the inverse of tStS.layout, then with sP.layout? + # cute.composition(sP_stage, tv_layout) might work if the codomains match + print(f"tStS layout: {tStS.layout}") + except Exception as e: + print(f"Error: {e}") + + +if __name__ == '__main__': + test_cotiled_copy_diag()