"""D1 diagnostic: Print TMEM budget and shapes at hd=256.

Flat script (run as ``python tests/unit/test_d1_diag.py``). It allocates
bf16 Q/K/V/C torch tensors on the CUDA device, wraps them as CuTe tensors,
builds the QK and PV tiled MMAs through
``utils.sm100.make_trivial_tiled_mma`` and prints the resulting shapes,
tile views and a TMEM column budget.

Everything executes at import time and requires a CUDA device, because the
tensors are created with ``device='cuda'``.

NOTE(review): the file name matches pytest's default ``test_*.py`` pattern,
so pytest would presumably execute this body during collection -- confirm
that this is intended.
NOTE(review): the CuTe calls below (``make_tensor``, ``local_tile``, ...)
are issued at module level; confirm they do not need to be wrapped in a
``@cute.jit`` function.
"""
import math

import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as ct
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda

# Problem extents: Q and C are (m, hd, 1), K is (n_kv, hd, 1), V is (n_kv, hd).
hd = 256
n_kv = 128
m = 128

torch.manual_seed(42)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n_kv, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device='cuda')
# V gets a trailing unit mode so that it is rank-3 like Q, K and C.
v_kernel = v.unsqueeze(-1)
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')

q_dtype = cutlass.BFloat16
qk_acc_dtype = cutlass.Float32
cta_group = tcgen05.CtaGroup.ONE

# CuTe views of the torch tensors.
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))

# Fix: query the major mode from the CuTe tensors, not the torch tensors.
# LayoutEnum.from_tensor takes a CuTe tensor; a torch.Tensor does not carry
# the leading-dimension information that helper inspects.
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()

pv_n_tile = min(hd, 256)
# Fix: take the iterator from the CuTe tensor mV; torch.Tensor has no
# `iterator` attribute. The layout re-expresses V as
# (pv_n_tile, n_kv, 1) with unit stride on the first mode.
v_fmha = cute.make_tensor(
    mV.iterator,
    cute.make_layout((pv_n_tile, n_kv, 1), stride=(1, pv_n_tile, pv_n_tile * n_kv)),
)
v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()

print(f"a_major={a_major}, b_major={b_major}, v_major={v_major}")
print(f"Q shape: {q.shape}, K shape: {k.shape}, V shape: {v.shape}")
print(f"mQ shape: {mQ.shape}, mK shape: {mK.shape}, mV shape: {mV.shape}, mC shape: {mC.shape}")

# QK MMA: both operands are sourced from SMEM, with a (128, 128) tile.
qk_mma = utils.sm100.make_trivial_tiled_mma(
    q_dtype,
    q_dtype,
    a_major,
    b_major,
    qk_acc_dtype,
    cta_group,
    (128, 128),
    tcgen05.OperandSource.SMEM,
)
print(f"\nQK MMA shape_mnk: {qk_mma.shape_mnk}")

# PV MMA: the A operand is sourced from TMEM, with a (128, pv_n_tile) tile.
pv_mma = utils.sm100.make_trivial_tiled_mma(
    q_dtype,
    q_dtype,
    cutlass.nvgpu.OperandMajorMode.K,
    v_major,
    qk_acc_dtype,
    cta_group,
    (128, pv_n_tile),
    tcgen05.OperandSource.TMEM,
)
print(f"PV MMA shape_mnk: {pv_mma.shape_mnk}")

# K extent of one MMA instruction, and the per-tile K extents derived from it.
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
# Largest multiple of pv_ik that does not exceed 128.
pv_k_inner = pv_ik * (128 // pv_ik)
pv_mma_tiler = (128, pv_n_tile, pv_k_inner)
print(f"qk_ik={qk_ik}, qk_mma_tiler={qk_mma_tiler}")
print(f"pv_ik={pv_ik}, pv_mma_tiler={pv_mma_tiler}")

# GMEM tile views
gQ = cute.local_tile(mQ, cute.slice_(qk_mma_tiler, (None, 0, None)), (None, None, None))
gK = cute.local_tile(mK, cute.slice_(qk_mma_tiler, (0, None, None)), (None, None, None))
gV = cute.local_tile(mV, cute.slice_(pv_mma_tiler, (0, None, None)), (None, None, None))
gC = cute.local_tile(mC, cute.slice_(pv_mma_tiler, (None, None, 0)), (None, None, None))
print(f"\ngQ shape: {gQ.shape}")
print(f"gK shape: {gK.shape}")
print(f"gV shape: {gV.shape}")
print(f"gC shape: {gC.shape}")
print(f"n_kv_tiles: {cute.size(gK, mode=[3])}")

# TMEM budget
# Accumulator fragments for the (M, N) part of each tiler.
qk_thr = qk_mma.get_slice(0)
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0)
pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)

tmem_p0_offset = 32
# PV K extent rescaled by the operand/accumulator bit-width ratio, i.e. the
# number of accumulator-width columns P occupies.
p_cols_fp32 = pv_mma_tiler[2] * q_dtype.width // qk_acc_dtype.width
o_cols = find_tmem_tensor_col_offset(tOtO)
s_cols = qk_mma_tiler[1]
p_end = tmem_p0_offset + p_cols_fp32
# O starts after whichever of S and P reaches further, rounded up to a
# multiple of 32 columns.
o_after = max(s_cols, p_end)
tmem_o0_offset = ((o_after + 31) // 32) * 32
total = tmem_o0_offset + o_cols

print(f"\n=== TMEM Budget (hd={hd}) ===")
print(f"S={s_cols}, P_fp32={p_cols_fp32}, P_end={p_end}, O_offset={tmem_o0_offset}, O_cols={o_cols}")
# 512 is the column budget this script checks against
# (presumably the TMEM capacity -- TODO confirm).
print(f"Total={total}, Fits={total<=512}")

# QK C-fragment info
print(f"\ntStS shape: {tStS.shape}, cosize: {cute.cosize(tStS.layout)}")
print(f"tOtO shape: {tOtO.shape}, cosize: {cute.cosize(tOtO.layout)}")