diff --git a/tests/unit/test_d1_3_layout_diag.py b/tests/unit/test_d1_3_layout_diag.py index ae8aab3f..eb0e5863 100644 --- a/tests/unit/test_d1_3_layout_diag.py +++ b/tests/unit/test_d1_3_layout_diag.py @@ -6,11 +6,14 @@ PV A-operand SMEM layout (destination) so we can write a proper register→SMEM copy for P values. This test ONLY prints shapes. No kernel execution. Pure layout analysis. +Uses FmhaKernel._setup to compute layouts the same way the kernel does. """ import torch, math import cutlass, cutlass.cute as cute, cutlass.utils as utils -from cutlass.cute.nvgpu import tcgen05 +from cutlass.cute.nvgpu import tcgen05, cpasync +from cutlass import Float32, BFloat16, Int32 from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset import cutlass.torch as ct @@ -22,94 +25,79 @@ def analyze_layouts(hd, s_k=128): m = 128 use_smem_p = hd > 64 pv_n_tile = min(hd, 256) + n_kv_tiles = s_k // 128 + # Create dummy tensors just for layout computation q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') - v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda') + c = torch.randn(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') - # Reconstruct V with FMHA layout - v_fmha = cute.make_tensor( - v.data_ptr(), - cute.make_layout((pv_n_tile, s_k, 1), stride=(1, pv_n_tile, pv_n_tile * s_k)), - ) + a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode() + b_major = LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_major_mode() + c_major = LayoutEnum.from_tensor(ct.from_dlpack(c)).mma_major_mode() + # Build QK MMA qk_mma = utils.sm100.make_trivial_tiled_mma( - cutlass.BFloat16, cutlass.BFloat16, - LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode(), - LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_mn_mode(), - cutlass.Float32, tcgen05.CtaGroup.ONE, (128, 128), - tcgen05.OperandSource.SMEM, + BFloat16, BFloat16, a_major, b_major, Float32, + tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM, ) - pv_a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode() if use_smem_p else cute.nvgpu.OperandMajorMode.K + # Build PV MMA (SMEM-P for hd>64, TMEM-P for hd<=64) + pv_a_major = a_major if use_smem_p else cute.nvgpu.OperandMajorMode.K pv_source = tcgen05.OperandSource.SMEM if use_smem_p else tcgen05.OperandSource.TMEM pv_mma = utils.sm100.make_trivial_tiled_mma( - cutlass.BFloat16, cutlass.BFloat16, - pv_a_major, - LayoutEnum.from_tensor(ct.from_dlpack(torch.randn(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda'))).mma_major_mode(), - cutlass.Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile), - pv_source, + BFloat16, BFloat16, pv_a_major, c_major, Float32, + tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source, ) - # MMA tiler + # Tiler computation (same as FmhaKernel._setup) qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) qk_mma_tiler = (128, 128, qk_ik * 4) pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik)) print(f"\n--- MMA info ---") - print(f"QK MMA shape_mnk: {qk_mma.shape_mnk}") - print(f"PV MMA shape_mnk: {pv_mma.shape_mnk}") - print(f"QK tiler: {qk_mma_tiler}") - print(f"PV tiler: {pv_mma_tiler}") - print(f"use_smem_p: {use_smem_p}") - print(f"PV A source: {pv_source}") + print(f" QK MMA shape_mnk: {qk_mma.shape_mnk}") + print(f" PV MMA shape_mnk: {pv_mma.shape_mnk}") + print(f" QK tiler: {qk_mma_tiler}") + print(f" PV tiler: {pv_mma_tiler}") + print(f" use_smem_p: {use_smem_p}") + print(f" PV A source: {pv_source}") + print(f" PV A major: {pv_a_major}") - # QK C-fragment (where S and P live in TMEM) + # QK C-fragment (where S lives in TMEM, where softmax reads P from) qk_thr = qk_mma.get_slice(0) qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2]) tStS = qk_thr.make_fragment_C(qk_as) - print(f"\n--- QK C-fragment (tStS) ---") + print(f"\n--- QK C-fragment (tStS = S and P in TMEM) ---") print(f" qk_as (partition shape): {qk_as}") print(f" tStS shape: {cute.shape(tStS)}") + print(f" tStS size: {cute.size(tStS)}") print(f" tStS layout: {tStS.layout}") - # P in TMEM (subset of S columns) - p_cols_fp32 = pv_mma_tiler[2] * 16 // 32 # BF16 width / FP32 width - print(f" p_cols_fp32 (PV MMA tiler[2] * 16/32): {p_cols_fp32}") - p_cols_fp32_v2 = pv_mma_tiler[2] * cutlass.BFloat16.width // cutlass.Float32.width - print(f" p_cols_fp32 (PV MMA tiler[2] * BF16_width/FP32_width): {p_cols_fp32_v2}") - # PV C-fragment (where O lives in TMEM) pv_thr = pv_mma.get_slice(0) pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2]) tOtO = pv_thr.make_fragment_C(pv_as) - print(f"\n--- PV C-fragment (tOtO) ---") + print(f"\n--- PV C-fragment (tOtO = O in TMEM) ---") print(f" pv_as (partition shape): {pv_as}") print(f" tOtO shape: {cute.shape(tOtO)}") + print(f" tOtO size: {cute.size(tOtO)}") print(f" tOtO layout: {tOtO.layout}") - # PV A-operand (where P is read from in PV GEMM) - # TMEM-P: tOrP from TMEM - # SMEM-P: tCrP from SMEM, tOrP from sP - p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, cutlass.BFloat16, 1) - p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, cutlass.BFloat16, 1) - - print(f"\n--- PV A-operand layouts ---") - print(f" p_tmem_s outer: {p_tmem_s.outer}") - print(f" p_tmem_s inner (swizzle): {p_smem_s.inner}") - - # Create sP tensor (SMEM allocation) - # We need to allocate sP with the PV A-operand SMEM layout + # PV A-operand SMEM layout (where P must be written for SMEM-P) + p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1) print(f"\n--- PV A-operand SMEM layout (sP) ---") print(f" p_smem_s outer: {p_smem_s.outer}") - print(f" p_smem_s outer shape: {cute.shape(cute.make_layout(p_smem_s.outer))}") - print(f" p_smem_s inner: {p_smem_s.inner}") + print(f" p_smem_s inner (swizzle): {p_smem_s.inner}") + outer_layout = cute.make_layout(p_smem_s.outer) + print(f" p_smem_s outer shape: {cute.shape(outer_layout)}") + print(f" p_smem_s outer size: {cute.size(outer_layout)}") - # What does the MMA A-operand fragment look like for SMEM-P? - # This is what the MMA warp reads from sP - # We need a dummy sP tensor to create the fragment - # Use composition to get the A-operand partition + # PV A-operand fragments (what MMA reads from TMEM or SMEM) + # For TMEM-P: tOrP reads from TMEM + # For SMEM-P: tCrP reads from SMEM, tOrP also reads from SMEM + p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1) tP_tmem = cute.make_tensor(tStS.iterator, p_tmem_s.outer) tP_smem = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer) @@ -118,17 +106,24 @@ def analyze_layouts(hd, s_k=128): tOrP_smem = pv_thr.make_fragment_A(tP_smem) print(f"\n--- PV A-operand fragments ---") - print(f" tOrP_tmem (TMEM-P read): shape={cute.shape(tOrP_tmem)} layout={tOrP_tmem.layout}") - print(f" tCrP_smem (SMEM-P load): shape={cute.shape(tCrP_smem)} layout={tCrP_smem.layout}") - print(f" tOrP_smem (SMEM-P read): shape={cute.shape(tOrP_smem)} layout={tOrP_smem.layout}") + print(f" tOrP_tmem (TMEM-P A-read): shape={cute.shape(tOrP_tmem)}") + print(f" tCrP_smem (SMEM-P A-load): shape={cute.shape(tCrP_smem)} layout={tCrP_smem.layout}") + print(f" tOrP_smem (SMEM-P A-read): shape={cute.shape(tOrP_smem)} layout={tOrP_smem.layout}") + print(f" tOrP_tmem size: {cute.size(tOrP_tmem)}") + print(f" tCrP_smem size: {cute.size(tCrP_smem)}") - # The softmax warps' view of P - # They have P in QK C-fragment layout (tStS), and they need to write to sP - # Let's understand the QK C-fragment partition per-thread - sfw_idx = 0 # thread 0 in softmax warps + # Slice tOrP the way the kernel does: [(None,None,None,0)] + tOrP_base_tmem = pv_thr.make_fragment_A(tP_tmem) + tOrP_tmem_sliced = tOrP_base_tmem[(None,None,None,0)] + tOrP_base_smem = pv_thr.make_fragment_A(tP_smem) + tOrP_smem_sliced = tOrP_base_smem[(None,None,None,0)] + print(f"\n--- PV A-operand after slice [(None,None,None,0)] ---") + print(f" tOrP_tmem_sliced: shape={cute.shape(tOrP_tmem_sliced)}") + print(f" tOrP_smem_sliced: shape={cute.shape(tOrP_smem_sliced)}") - # TMEM load/store atoms for S/P - tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32) + # Softmax thread partition of S (via TMEM load atoms) + sfw_idx = 0 + tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32) tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS) thr_load = tiled_tmem_load.get_slice(sfw_idx) tTMEM_LOADtS = thr_load.partition_S(tStS) @@ -136,98 +131,109 @@ def analyze_layouts(hd, s_k=128): tScS = qk_thr.partition_C(cS) tTMEM_LOADcS = thr_load.partition_D(tScS) - print(f"\n--- Softmax thread 0 TMEM load partition ---") + print(f"\n--- Softmax thread 0: TMEM load partition ---") print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}") print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}") print(f" tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}") - print(f" => Each softmax thread has {cute.size(tTMEM_LOADcS)} S coordinates") + print(f" => Each softmax thread has {cute.size(tTMEM_LOADcS)} S/P coordinates") - # Now try make_tiled_copy_C with qk_mma - print(f"\n--- make_tiled_copy_C(qk_mma) attempt ---") + # P register bridge: what P values does each softmax thread have? + # After softmax, P is in rP_bf16 which shares layout with rS + # rP_bf16 = make_tensor(recast_ptr(rP_words.iterator, dtype=BF16), tTMEM_LOADrS.layout) + # where tTMEM_LOADrS = make_rmem_tensor(tTMEM_LOADcS.shape, FP32) + p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width + tStP_layout = cute.composition(tStS.layout, cute.make_layout((pv_mma_tiler[0], p_cols_fp32))) + print(f"\n--- P columns in TMEM ---") + print(f" p_cols_fp32 (PV tiler[2] * BF16/FP32): {p_cols_fp32}") + print(f" P TMEM sub-layout (composition of tStS): {tStP_layout}") + print(f" P TMEM sub-layout shape: {cute.shape(cute.make_layout(tStP_layout))}") + + # Now the critical question: what's the sP shape that the softmax needs to write to? + sP_alloc = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer) + sP_2d = cute.group_modes(sP_alloc, 0, 3) + print(f"\n--- sP for SMEM-P write target ---") + print(f" sP shape (3D): {cute.shape(sP_alloc)}") + print(f" sP_2d shape (grouped): {cute.shape(sP_2d)}") + print(f" sP_2d rank: {len(cute.shape(sP_2d))}") + + # Try make_tiled_copy_C with QK MMA (the approach that gave rank mismatch) + print(f"\n--- make_tiled_copy_C(smem_copy_atom, qk_mma) ---") try: smem_copy_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), - cutlass.BFloat16, + BFloat16, num_bits_per_copy=128, ) - tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma) - thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx) - sP_2d = cute.group_modes(cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer), 0, 3) - tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d) - print(f" tiled_smem_copy: created OK") - print(f" sP_2d shape: {cute.shape(sP_2d)} rank: {len(cute.shape(sP_2d))}") - print(f" tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}") + tiled_copy_qk = cute.make_tiled_copy_C(smem_copy_atom, qk_mma) + thr_qk = tiled_copy_qk.get_slice(sfw_idx) + tSMEM_QKsP = thr_qk.partition_D(sP_2d) + print(f" partition_D(sP_2d) shape: {cute.shape(tSMEM_QKsP)} rank={len(cute.shape(tSMEM_QKsP))}") - # Try partition_S with the rP register layout - # rP is in the same layout as rS (TMEM load partition) - rP_layout = tTMEM_LOADcS.layout # coordinate layout - rP_bf16 = cute.make_tensor(cute.find(rP_layout, 0), rP_layout) + # Try partition_S with various source layouts + # 1. Identity tensor (coordinate) layout try: - tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16) - print(f" tSMEM_CPYrP (from coord layout) shape: {cute.shape(tSMEM_CPYrP)} rank: {len(cute.shape(tSMEM_CPYrP))}") + rP_coord = cute.make_tensor(cute.find(tTMEM_LOADcS.layout, 0), tTMEM_LOADcS.layout) + tSMEM_QKrP = thr_qk.partition_S(rP_coord) + print(f" partition_S(coord layout) shape: {cute.shape(tSMEM_QKrP)} rank={len(cute.shape(tSMEM_QKrP))}") except Exception as e: - print(f" partition_S with coord layout FAILED: {e}") + print(f" partition_S(coord layout) FAILED: {e}") - # Try with tStS layout (QK C-fragment) + # 2. QK C-fragment layout (tStS) try: rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout) - tSMEM_CPYrP = thr_smem_copy.partition_S(rP_qk) - print(f" tSMEM_CPYrP (from QK C-frag layout) shape: {cute.shape(tSMEM_CPYrP)} rank: {len(cute.shape(tSMEM_CPYrP))}") + tSMEM_QKrP = thr_qk.partition_S(rP_qk) + print(f" partition_S(QK C-frag layout) shape: {cute.shape(tSMEM_QKrP)} rank={len(cute.shape(tSMEM_QKrP))}") except Exception as e: - print(f" partition_S with QK C-frag layout FAILED: {e}") + print(f" partition_S(QK C-frag layout) FAILED: {e}") + + # 3. The rmem layout that softmax actually has P in + rP_rmem_layout = cute.make_layout(tTMEM_LOADcS.shape) + try: + rP_rmem = cute.make_tensor(cute.find(rP_rmem_layout, 0), rP_rmem_layout) + tSMEM_QKrP = thr_qk.partition_S(rP_rmem) + print(f" partition_S(rmem layout) shape: {cute.shape(tSMEM_QKrP)} rank={len(cute.shape(tSMEM_QKrP))}") + except Exception as e: + print(f" partition_S(rmem layout) FAILED: {e}") except Exception as e: - print(f" make_tiled_copy_C FAILED: {e}") + print(f" make_tiled_copy_C(qk_mma) FAILED: {e}") - # Now try make_tiled_copy (NOT make_tiled_copy_C) with pv_mma - print(f"\n--- make_tiled_copy with PV A-operand SMEM layout ---") + # Try make_tiled_copy_C with PV MMA (different approach) + print(f"\n--- make_tiled_copy_C(smem_copy_atom, pv_mma) ---") try: - # Use the PV MMA's A-operand thread mapping - # make_tiled_copy uses an M/N layout to determine thread partitioning - store_atom = cute.make_copy_atom( + smem_copy_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), - cutlass.BFloat16, + BFloat16, num_bits_per_copy=128, ) - # Create a tiled copy that uses the QK MMA's C-fragment layout for thread partitioning - # but writes to the PV A-operand SMEM layout - tiled_copy_pv = cute.make_tiled_copy_C(store_atom, pv_mma) - print(f" make_tiled_copy_C(pv_mma) created OK") + tiled_copy_pv = cute.make_tiled_copy_C(smem_copy_atom, pv_mma) thr_pv = tiled_copy_pv.get_slice(sfw_idx) - sP_2d = cute.group_modes(cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer), 0, 3) tSMEM_PVsP = thr_pv.partition_D(sP_2d) - print(f" tSMEM_PVsP shape: {cute.shape(tSMEM_PVsP)} rank: {len(cute.shape(tSMEM_PVsP))}") + print(f" partition_D(sP_2d) shape: {cute.shape(tSMEM_PVsP)} rank={len(cute.shape(tSMEM_PVsP))}") except Exception as e: print(f" make_tiled_copy_C(pv_mma) FAILED: {e}") - # Print the full layout details for manual analysis - print(f"\n--- Detailed layout info ---") - print(f" p_smem_s.outer type: {type(p_smem_s.outer)}") - print(f" p_smem_s.outer: {p_smem_s.outer}") - if hasattr(p_smem_s.outer, 'shape'): - print(f" p_smem_s.outer.shape: {p_smem_s.outer.shape}") - if hasattr(p_smem_s.outer, 'stride'): - print(f" p_smem_s.outer.stride: {p_smem_s.outer.stride}") - - # Understand the P matrix dimensions - # P is (128, n_kv) in the attention computation - # But PV MMA sees P as (128, pv_n_tile) A-operand - # For hd=64: P is (128, 64) with n_kv=128 KV positions - # Actually P is (M, K_PV) where K_PV = n_kv for each head - # Wait - PV is P[M, n_kv] @ V[n_kv, hd] - # So P's second dimension is n_kv (= 128 for s_k=128) - print(f"\n--- P matrix dimensions ---") - print(f" P is (128, {s_k}) attention weights matrix") - print(f" PV GEMM: P[128, {s_k}] @ V[{s_k}, {pv_n_tile}] = O[128, {pv_n_tile}]") - print(f" PV A-operand shape: (128, {s_k}) = QK MMA's K dimension") - print(f" But PV MMA tiler K dimension: {pv_mma_tiler[2]}") - print(f" So PV A-operand is (128, {pv_mma_tiler[2]}) per PV sub-tile") - - # Key insight: P's dimensions are (128, n_kv) in the QK MMA's K dimension - # The QK C-fragment has shape ((128, 128), 1, 1, 1) for S - # The PV A-operand SMEM has shape that depends on PV MMA tiling - # These represent the SAME data (128 rows, 128 KV positions) - # But they're partitioned differently by the two MMAs + # Try make_tiled_copy (not _C) with explicit M/N layout + print(f"\n--- make_tiled_copy (manual M/N) ---") + try: + # The P matrix is (128, n_kv) = (M=128, N_P=128) + # Use (128, 128) tile for the copy + MN = cute.make_layout((128, 128)) + smem_copy_atom2 = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + BFloat16, + num_bits_per_copy=128, + ) + # Need to use make_tiled_copy with proper thread arrangement + # This is for the softmax warps (128 threads = 4 warps) + thr_layout = cute.make_layout((4, 32)) # 4 warps × 32 lanes + val_layout = cute.make_layout((1, 1)) # each thread copies 1 element + tiled_manual = cute.make_tiled_copy(smem_copy_atom2, thr_layout, val_layout) + thr_manual = tiled_manual.get_slice(sfw_idx) + tMANUALsP = thr_manual.partition_D(sP_2d) + print(f" partition_D(sP_2d) shape: {cute.shape(tMANUALsP)} rank={len(cute.shape(tMANUALsP))}") + except Exception as e: + print(f" make_tiled_copy (manual) FAILED: {e}") if __name__ == '__main__':