"""
D1.3 SMEM-P diagnostic: print ALL relevant shapes and layouts.
Runs layout analysis inside a JIT-compiled function to get MLIR context.
"""
import torch
import math
import sys
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05, cpasync
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cutlass.torch as ct


def _run_layout_diag(hd, m=128, s_k=128):
    """Print the MMA, fragment and sP layouts for a single head dimension.

    The QK and PV tiled MMAs are built by hand inside a traced function and
    every derived shape/layout is printed.  Nothing is asserted; the output
    is meant to be read by a human.

    Args:
        hd: Head dimension.  ``hd > 64`` selects the SMEM source for the PV
            A operand, otherwise the TMEM source is used.
        m: Row count of the dummy Q / C tensors.
        s_k: Row count of the dummy K / V tensors.
    """
    print(f"\n{'='*60}")
    print(f" HEAD_DIM={hd}")
    print(f"{'='*60}")

    use_smem_p = hd > 64
    pv_n_tile = min(hd, 256)

    # Dummy device tensors; only their layouts matter for the analysis.
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')

    # All three major modes are derived on the host from the torch tensors so
    # that the traced function below only closes over plain Python values.
    a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode()
    b_major = LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_major_mode()
    c_major = LayoutEnum.from_tensor(ct.from_dlpack(c)).mma_major_mode()

    # NOTE(review): this function is handed straight to cute.compile below
    # although it is decorated with @cute.kernel rather than @cute.jit --
    # confirm against the CuTe DSL docs that this is a supported entry point.
    @cute.kernel
    def _diag_kernel(_q, _k, _v, _c):
        # The tensor arguments are unused; they exist so that the function
        # can be compiled and invoked with real tensors.

        # Build QK MMA
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16, BFloat16, a_major, b_major, Float32,
            tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM,
        )

        # Build PV MMA (SMEM source for hd > 64, TMEM source otherwise)
        pv_a_major = a_major if use_smem_p else cute.nvgpu.OperandMajorMode.K
        pv_source = tcgen05.OperandSource.SMEM if use_smem_p else tcgen05.OperandSource.TMEM
        pv_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16, BFloat16, pv_a_major, c_major, Float32,
            tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source,
        )

        qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
        qk_mma_tiler = (128, 128, qk_ik * 4)
        pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
        pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))

        print(f" QK MMA shape_mnk: {qk_mma.shape_mnk}")
        print(f" PV MMA shape_mnk: {pv_mma.shape_mnk}")
        print(f" QK tiler: {qk_mma_tiler}")
        print(f" PV tiler: {pv_mma_tiler}")
        print(f" use_smem_p: {use_smem_p}")

        # QK C-fragment
        qk_thr = qk_mma.get_slice(0)
        qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_as)
        print(f" tStS shape: {cute.shape(tStS)} layout: {tStS.layout}")

        # PV C-fragment
        pv_thr = pv_mma.get_slice(0)
        pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_as)
        print(f" tOtO shape: {cute.shape(tOtO)} layout: {tOtO.layout}")

        # PV A-operand SMEM layout (sP)
        p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
        sP_alloc = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
        print(f" sP shape: {cute.shape(sP_alloc)}")
        sP_2d = cute.group_modes(sP_alloc, 0, 3)
        print(f" sP_2d shape: {cute.shape(sP_2d)} rank={len(cute.shape(sP_2d))}")
        print(f" sP_2d layout: {sP_2d.layout}")

        # PV A-operand fragments
        tP_tmem = cute.make_tensor(tStS.iterator, p_smem_s.outer)  # dummy tmem ptr
        tP_smem = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
        tOrP_tmem = pv_thr.make_fragment_A(tP_tmem)
        tCrP_smem = pv_mma.make_fragment_A(tP_smem)
        tOrP_smem = pv_thr.make_fragment_A(tP_smem)
        print(f" tOrP_tmem shape: {cute.shape(tOrP_tmem)}")
        print(f" tCrP_smem shape: {cute.shape(tCrP_smem)} layout: {tCrP_smem.layout}")
        print(f" tOrP_smem shape: {cute.shape(tOrP_smem)}")

        # Slice as kernel does
        tOrP_tmem_s = tOrP_tmem[(None, None, None, 0)]
        tOrP_smem_s = tOrP_smem[(None, None, None, 0)]
        print(f" tOrP_tmem sliced: {cute.shape(tOrP_tmem_s)}")
        print(f" tOrP_smem sliced: {cute.shape(tOrP_smem_s)}")

        # Softmax TMEM load partition
        sfw_idx = 0
        tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
        tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
        thr_load = tiled_tmem_load.get_slice(sfw_idx)
        # Result is not printed; the call is kept so the partition is still
        # exercised during tracing.
        tTMEM_LOADtS = thr_load.partition_S(tStS)
        cS = cute.make_identity_tensor((128, 128))
        tScS = qk_thr.partition_C(cS)
        tTMEM_LOADcS = thr_load.partition_D(tScS)
        print(f" tTMEM_LOADcS (coord) shape: {cute.shape(tTMEM_LOADcS)} layout: {tTMEM_LOADcS.layout}")

        # P sub-layout in TMEM
        p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width
        print(f" p_cols_fp32: {p_cols_fp32}")

        # The try/except blocks below are deliberate best-effort probes: a
        # failing call is reported and the remaining probes still run.

        # make_tiled_copy_C with qk_mma
        print(f"\n --- make_tiled_copy_C(qk_mma) ---")
        try:
            smem_copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), BFloat16, num_bits_per_copy=128)
            tiled_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
            thr_copy = tiled_copy.get_slice(sfw_idx)
            print(f" OK, created tiled_copy")
            try:
                tD = thr_copy.partition_D(sP_2d)
                print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")
            except Exception as e:
                print(f" partition_D(sP_2d) FAILED: {e}")

            # Try source partition with coord layout
            try:
                rP_coord = cute.make_tensor(cute.find(tTMEM_LOADcS.layout, 0), tTMEM_LOADcS.layout)
                tS = thr_copy.partition_S(rP_coord)
                print(f" partition_S(coord) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
            except Exception as e:
                print(f" partition_S(coord) FAILED: {e}")

            # Try source with QK C-frag layout
            try:
                rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout)
                # Positional, like the coord probe above: a mismatched
                # keyword name would be reported here as a layout failure.
                tS = thr_copy.partition_S(rP_qk)
                print(f" partition_S(QK C-frag) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
            except Exception as e:
                print(f" partition_S(QK C-frag) FAILED: {e}")
        except Exception as e:
            print(f" make_tiled_copy_C FAILED: {e}")

        # make_tiled_copy_C with pv_mma
        print(f"\n --- make_tiled_copy_C(pv_mma) ---")
        try:
            # Built again here so this probe does not depend on the qk_mma
            # probe above having got as far as creating its copy atom.
            pv_copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), BFloat16, num_bits_per_copy=128)
            tiled_copy_pv = cute.make_tiled_copy_C(pv_copy_atom, pv_mma)
            thr_copy_pv = tiled_copy_pv.get_slice(sfw_idx)
            print(f" OK, created tiled_copy_pv")
            try:
                tD = thr_copy_pv.partition_D(sP_2d)
                print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")
            except Exception as e:
                print(f" partition_D(sP_2d) FAILED: {e}")
        except Exception as e:
            print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")

    # Compile and run
    v3 = v.unsqueeze(-1)
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v3).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v3))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))

    print(' Compiling diagnostic...', flush=True)
    compiled = cute.compile(_diag_kernel, mQ, mK, mV, mC)
    compiled(mQ, mK, mV, mC)
    torch.cuda.synchronize()


def main():
    """Run the layout diagnostic for head dimensions 64 and 256."""
    for hd in [64, 256]:
        _run_layout_diag(hd)


if __name__ == '__main__':
    main()