# Recovered from a "new file" patch for tests/unit/test_d1_3_layout_diag.py.
"""
D1.3 SMEM-P diagnostic: print ALL relevant shapes and layouts.

Goal: Understand the exact QK C-fragment partition (source) and
PV A-operand SMEM layout (destination) so we can write a proper
register→SMEM copy for P values.

This test ONLY prints shapes. No kernel execution. Pure layout analysis.
"""
import math

import torch

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass.utils import LayoutEnum
import cutlass.torch as ct


def _section(title):
    """Print a blank line followed by a ``--- title ---`` section header."""
    print(f"\n--- {title} ---")


def analyze_layouts(hd, s_k=128):
    """Print the MMA shapes, fragments and SMEM layouts relevant to moving P from QK to PV.

    Nothing is returned; every result is written to stdout. The function
    builds a QK tiled MMA and a PV tiled MMA, then prints their C-fragments,
    the PV A-operand SMEM layout, the per-thread TMEM load partition, and the
    outcome of two ``make_tiled_copy_C`` experiments.

    Args:
        hd: Head dimension. ``hd > 64`` selects the SMEM-P configuration
            (PV A operand sourced from SMEM); otherwise the TMEM-P one.
        s_k: Number of KV positions (second dimension of P). Defaults to 128.

    Raises:
        Whatever torch / cutlass raise outside the two copy experiments.
        Inside those experiments every ``Exception`` is caught and printed,
        because finding out *which* step fails is the point of the diagnostic.

    Note:
        The dummy tensors are allocated with ``device='cuda'``.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print(f" HEAD_DIM={hd}, s_k={s_k}")
    print(f"{rule}")

    m = 128
    use_smem_p = hd > 64
    pv_n_tile = min(hd, 256)

    # Dummy operands: only their layouts are inspected, never their values.
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')

    # Reconstruct V with FMHA layout.
    # NOTE(review): v_fmha is not read anywhere below; kept so the diagnostic
    # still exercises this construction exactly as before.
    v_fmha = cute.make_tensor(
        v.data_ptr(),
        cute.make_layout((pv_n_tile, s_k, 1), stride=(1, pv_n_tile, pv_n_tile * s_k)),
    )

    qk_mma = utils.sm100.make_trivial_tiled_mma(
        cutlass.BFloat16, cutlass.BFloat16,
        LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode(),
        LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_mn_mode(),
        cutlass.Float32, tcgen05.CtaGroup.ONE, (128, 128),
        tcgen05.OperandSource.SMEM,
    )

    # SMEM-P takes the A major mode from q's layout; TMEM-P forces K-major.
    if use_smem_p:
        pv_a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode()
        pv_source = tcgen05.OperandSource.SMEM
    else:
        pv_a_major = cute.nvgpu.OperandMajorMode.K
        pv_source = tcgen05.OperandSource.TMEM

    pv_b_dummy = torch.randn(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        cutlass.BFloat16, cutlass.BFloat16,
        pv_a_major,
        LayoutEnum.from_tensor(ct.from_dlpack(pv_b_dummy)).mma_major_mode(),
        cutlass.Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile),
        pv_source,
    )

    # MMA tiler
    qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
    qk_mma_tiler = (128, 128, qk_ik * 4)
    pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
    pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))

    _section("MMA info")
    print(f"QK MMA shape_mnk: {qk_mma.shape_mnk}")
    print(f"PV MMA shape_mnk: {pv_mma.shape_mnk}")
    print(f"QK tiler: {qk_mma_tiler}")
    print(f"PV tiler: {pv_mma_tiler}")
    print(f"use_smem_p: {use_smem_p}")
    print(f"PV A source: {pv_source}")

    # QK C-fragment (where S and P live in TMEM)
    qk_thr = qk_mma.get_slice(0)
    qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_as)
    _section("QK C-fragment (tStS)")
    print(f" qk_as (partition shape): {qk_as}")
    print(f" tStS shape: {cute.shape(tStS)}")
    print(f" tStS layout: {tStS.layout}")

    # P in TMEM (subset of S columns). The value is computed twice on purpose:
    # once with literal bit widths and once from the dtype objects, so the
    # two printed numbers can be cross-checked.
    p_cols_fp32 = pv_mma_tiler[2] * 16 // 32  # BF16 width / FP32 width
    print(f" p_cols_fp32 (PV MMA tiler[2] * 16/32): {p_cols_fp32}")
    p_cols_fp32_v2 = pv_mma_tiler[2] * cutlass.BFloat16.width // cutlass.Float32.width
    print(f" p_cols_fp32 (PV MMA tiler[2] * BF16_width/FP32_width): {p_cols_fp32_v2}")

    # PV C-fragment (where O lives in TMEM)
    pv_thr = pv_mma.get_slice(0)
    pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
    tOtO = pv_thr.make_fragment_C(pv_as)
    _section("PV C-fragment (tOtO)")
    print(f" pv_as (partition shape): {pv_as}")
    print(f" tOtO shape: {cute.shape(tOtO)}")
    print(f" tOtO layout: {tOtO.layout}")

    # PV A-operand (where P is read from in PV GEMM)
    #   TMEM-P: tOrP from TMEM
    #   SMEM-P: tCrP from SMEM, tOrP from sP
    # NOTE(review): both layouts are produced by the *same* call, so they are
    # expected to be identical; confirm whether the TMEM variant was meant to
    # be built differently.
    p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, cutlass.BFloat16, 1)
    p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, cutlass.BFloat16, 1)

    _section("PV A-operand layouts")
    print(f" p_tmem_s outer: {p_tmem_s.outer}")
    # Fixed: this line used to print p_smem_s.inner under the p_tmem_s label.
    print(f" p_tmem_s inner (swizzle): {p_tmem_s.inner}")

    # sP would be allocated with the PV A-operand SMEM layout.
    _section("PV A-operand SMEM layout (sP)")
    print(f" p_smem_s outer: {p_smem_s.outer}")
    print(f" p_smem_s outer shape: {cute.shape(cute.make_layout(p_smem_s.outer))}")
    print(f" p_smem_s inner: {p_smem_s.inner}")

    def _dummy_sp():
        """Build a placeholder sP tensor carrying the PV A-operand SMEM layout."""
        # NOTE(review): the result of cute.find(...) is used as the tensor
        # iterator here; confirm this yields a usable placeholder.
        return cute.make_tensor(
            cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer
        )

    # What does the MMA A-operand fragment look like for SMEM-P?
    # This is what the MMA warp reads from sP; a dummy sP tensor is enough
    # to create the fragment.
    tP_tmem = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
    tP_smem = _dummy_sp()

    tOrP_tmem = pv_thr.make_fragment_A(tP_tmem)
    tCrP_smem = pv_mma.make_fragment_A(tP_smem)
    tOrP_smem = pv_thr.make_fragment_A(tP_smem)

    _section("PV A-operand fragments")
    print(f" tOrP_tmem (TMEM-P read): shape={cute.shape(tOrP_tmem)} layout={tOrP_tmem.layout}")
    print(f" tCrP_smem (SMEM-P load): shape={cute.shape(tCrP_smem)} layout={tCrP_smem.layout}")
    print(f" tOrP_smem (SMEM-P read): shape={cute.shape(tOrP_smem)} layout={tOrP_smem.layout}")

    # The softmax warps' view of P: they hold P in the QK C-fragment layout
    # (tStS) and need to write it to sP. Inspect the per-thread partition.
    sfw_idx = 0  # thread 0 in softmax warps

    # TMEM load/store atoms for S/P
    tmem_load_atom = cute.make_copy_atom(
        tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32
    )
    tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
    thr_load = tiled_tmem_load.get_slice(sfw_idx)
    tTMEM_LOADtS = thr_load.partition_S(tStS)
    cS = cute.make_identity_tensor((128, 128))
    tScS = qk_thr.partition_C(cS)
    tTMEM_LOADcS = thr_load.partition_D(tScS)

    _section("Softmax thread 0 TMEM load partition")
    print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}")
    print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
    print(f" tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}")
    print(f" => Each softmax thread has {cute.size(tTMEM_LOADcS)} S coordinates")

    # Experiment 1: make_tiled_copy_C driven by the QK MMA.
    # Broad excepts are deliberate: the failure message is the diagnostic.
    _section("make_tiled_copy_C(qk_mma) attempt")
    try:
        smem_copy_atom = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            cutlass.BFloat16,
            num_bits_per_copy=128,
        )
        tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
        thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx)
        sP_2d = cute.group_modes(_dummy_sp(), 0, 3)
        tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d)
        print(" tiled_smem_copy: created OK")
        print(f" sP_2d shape: {cute.shape(sP_2d)} rank: {len(cute.shape(sP_2d))}")
        print(f" tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}")

        # Try partition_S with the rP register layout.
        # rP is in the same layout as rS (TMEM load partition).
        rP_layout = tTMEM_LOADcS.layout  # coordinate layout
        rP_bf16 = cute.make_tensor(cute.find(rP_layout, 0), rP_layout)
        try:
            tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16)
            print(f" tSMEM_CPYrP (from coord layout) shape: {cute.shape(tSMEM_CPYrP)} rank: {len(cute.shape(tSMEM_CPYrP))}")
        except Exception as e:
            print(f" partition_S with coord layout FAILED: {e}")

        # Try with tStS layout (QK C-fragment)
        try:
            rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout)
            tSMEM_CPYrP = thr_smem_copy.partition_S(rP_qk)
            print(f" tSMEM_CPYrP (from QK C-frag layout) shape: {cute.shape(tSMEM_CPYrP)} rank: {len(cute.shape(tSMEM_CPYrP))}")
        except Exception as e:
            print(f" partition_S with QK C-frag layout FAILED: {e}")

    except Exception as e:
        print(f" make_tiled_copy_C FAILED: {e}")

    # Experiment 2: make_tiled_copy_C driven by the PV MMA, writing to the
    # PV A-operand SMEM layout. The section header previously announced a
    # plain make_tiled_copy, which is not what is called here; it now matches
    # the OK/FAILED messages below.
    _section("make_tiled_copy_C(pv_mma) with PV A-operand SMEM layout")
    try:
        store_atom = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            cutlass.BFloat16,
            num_bits_per_copy=128,
        )
        # Thread partitioning comes from pv_mma (not the QK MMA); the
        # destination is the PV A-operand SMEM layout.
        tiled_copy_pv = cute.make_tiled_copy_C(store_atom, pv_mma)
        print(" make_tiled_copy_C(pv_mma) created OK")
        thr_pv = tiled_copy_pv.get_slice(sfw_idx)
        sP_2d = cute.group_modes(_dummy_sp(), 0, 3)
        tSMEM_PVsP = thr_pv.partition_D(sP_2d)
        print(f" tSMEM_PVsP shape: {cute.shape(tSMEM_PVsP)} rank: {len(cute.shape(tSMEM_PVsP))}")
    except Exception as e:
        print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")

    # Print the full layout details for manual analysis
    _section("Detailed layout info")
    print(f" p_smem_s.outer type: {type(p_smem_s.outer)}")
    print(f" p_smem_s.outer: {p_smem_s.outer}")
    if hasattr(p_smem_s.outer, 'shape'):
        print(f" p_smem_s.outer.shape: {p_smem_s.outer.shape}")
    if hasattr(p_smem_s.outer, 'stride'):
        print(f" p_smem_s.outer.stride: {p_smem_s.outer.stride}")

    # P matrix dimensions:
    #   PV is P[M, n_kv] @ V[n_kv, hd], so P's second dimension is n_kv
    #   (= 128 for s_k=128), while the PV MMA consumes it in K-sized sub-tiles.
    _section("P matrix dimensions")
    print(f" P is (128, {s_k}) attention weights matrix")
    print(f" PV GEMM: P[128, {s_k}] @ V[{s_k}, {pv_n_tile}] = O[128, {pv_n_tile}]")
    print(f" PV A-operand shape: (128, {s_k}) = QK MMA's K dimension")
    print(f" But PV MMA tiler K dimension: {pv_mma_tiler[2]}")
    print(f" So PV A-operand is (128, {pv_mma_tiler[2]}) per PV sub-tile")

    # Key insight: the QK C-fragment (S) and the PV A-operand SMEM tile
    # represent the SAME data (128 rows, 128 KV positions), but the two MMAs
    # partition it differently.


if __name__ == '__main__':
    analyze_layouts(64)   # TMEM-P (works)
    analyze_layouts(256)  # SMEM-P (broken)