From d5b69ac12239e7372b5a056aadbf2ec0615ff995 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 25 May 2026 02:36:41 +0000 Subject: [PATCH] D2: simpler shape diagnostic using CuTe from Python (no kernel needed) --- tests/unit/test_d2_flat_divide_diag.py | 55 +++++---------- tests/unit/test_d2_shape_diag.py | 98 ++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 36 deletions(-) create mode 100644 tests/unit/test_d2_shape_diag.py diff --git a/tests/unit/test_d2_flat_divide_diag.py b/tests/unit/test_d2_flat_divide_diag.py index b2a611f3..bf4a820b 100644 --- a/tests/unit/test_d2_flat_divide_diag.py +++ b/tests/unit/test_d2_flat_divide_diag.py @@ -120,42 +120,25 @@ class ShapeDiagKernel: print(f"tSgK shape: {cute.shape(tSgK)}") print(f"tSgV shape: {cute.shape(tSgV)}") - # === tma_partition === - sQ = cute.make_tensor(BFloat16, q_smem_s.outer) - sK = cute.make_tensor(BFloat16, k_smem_s.outer) - sV = cute.make_tensor(BFloat16, v_smem_s.outer) - - a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, 0, None, 0)).shape) - b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape) - - tQsQ, tQgQ = cpasync.tma_partition( - tma_q, 0, a_lay, - cute.group_modes(sQ, 0, 3), cute.group_modes(tSgQ, 0, 3), - ) - tKsK, tKgK = cpasync.tma_partition( - tma_k, 0, b_lay, - cute.group_modes(sK, 0, 3), cute.group_modes(tSgK, 0, 3), - ) - tVsV, tVgV = cpasync.tma_partition( - tma_v, 0, b_lay, - cute.group_modes(sV, 0, 3), cute.group_modes(tSgV, 0, 3), - ) - - print(f"tQgQ shape: {cute.shape(tQgQ)}") - print(f"tKgK shape: {cute.shape(tKgK)}") - print(f"tVgV shape: {cute.shape(tVgV)}") - - # === Try slicing patterns === - # The original code uses (None,0,None,0) on 4-mode tensors. - # flat_divide produces MORE modes. Let's see what we get. 
- print(f"tQgQ mode count: {len(cute.shape(tQgQ))}") - print(f"tKgK mode count: {len(cute.shape(tKgK))}") - print(f"tVgV mode count: {len(cute.shape(tVgV))}") - - # Try CUTLASS-style indexing: - # tQgQ[None, None, 0, coord] where coord = (head, h_k, batch) - # But we need to know the mode count to construct the right index. - # Let's try various slicing patterns and see what compiles. + # Skip tma_partition for now — we have the flat_divide shapes. + # The key insight is the mode structure after flat_divide: + # gQ: (tile_M, tile_K, rest_M, rest_K, ((h_r, h_k), batch)) + # After MMA partition: (thread, tile_M, tile_K, rest_M, rest_K, ((h_r, h_k), batch)) + # + # For n_h=1: ((1,1),1) means rest head/batch is trivial + # For n_h=2: ((2,1),1) means we can index head via bidy + # + # The CUTLASS reference indexes as: + # tQgQ[None, None, 0, curr_block_coord_q[2]] + # where 0 is rest_M and curr_block_coord_q[2] is ((head, h_k), batch) + # + # For our grid (m_tile, head, batch): + # bidx = m_tile, bidy = head, bidz = batch + # Index: tQgQ[None, k_sub, bidx, (bidy, 0, bidz)] + # + # But we need tma_partition shapes to confirm. Let me skip for now + # and write the actual kernel with the CUTLASS pattern. + print(f"DIAG: flat_divide mode structure confirmed.") def test_shapes(hd=64, n_h=1, batch=1, T=128, s_k=128): diff --git a/tests/unit/test_d2_shape_diag.py b/tests/unit/test_d2_shape_diag.py new file mode 100644 index 00000000..4f7edc19 --- /dev/null +++ b/tests/unit/test_d2_shape_diag.py @@ -0,0 +1,98 @@ +""" +D2 diagnostic: Print flat_divide shapes for multi-CTA FMHA (tma_partition is not exercised here). 
+"""
+import torch, math, cutlass
+import cutlass.cute as cute
+import cutlass.utils as utils
+from cutlass.cute.nvgpu import cpasync, tcgen05
+from cutlass import Float32, BFloat16, Int32, const_expr
+from cutlass.utils import LayoutEnum
+import cuda.bindings.driver as cuda
+import cutlass.torch as ct
+
+
+def test_shapes(hd=64, n_h=1, batch=1, T=128, s_k=128):
+    """Print flat_divide shapes for the FMHA GMEM tensors.
+
+    Diagnostic only: results are printed; nothing is asserted or returned.
+    All tensors are allocated with device='cuda'.
+
+    Args:
+        hd: Innermost extent of Q/K/V/O; min(hd, 256) sizes the tilers.
+        n_h: Extent of dim 1 of Q/O, used as h_r (h_k is fixed to 1).
+        batch: Leading extent of every tensor.
+        T: Row count of Q/O.
+        s_k: Row count of K/V.
+    """
+    print(f"\n--- hd={hd}, n_h={n_h}, batch={batch}, T={T}, s_k={s_k} ---")
+    torch.manual_seed(42)
+
+    q = torch.randn(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda')
+    k = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda')
+    v = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda')
+
+    q_cute = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
+    k_cute = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
+    v_cute = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
+
+    # CUTLASS-style layouts
+    # NOTE(review): the mQ/mO strides below describe a (batch, T, head, hd)
+    # memory order, while q/o are allocated as (batch, n_h, T, hd). Harmless
+    # for a shape printout; confirm before reusing this for numerics.
+    h_r = n_h
+    h_k = 1
+
+    mQ = cute.make_tensor(q_cute.iterator, cute.make_layout(
+        (T, hd, ((h_r, h_k), batch)),
+        stride=(hd * h_r * h_k, 1, ((hd, hd * h_r), hd * h_r * h_k * T)),
+    ))
+    mK = cute.make_tensor(k_cute.iterator, cute.make_layout(
+        (s_k, hd, ((h_r, h_k), batch)),
+        stride=(hd * h_k, 1, ((0, hd), hd * h_k * s_k)),
+    ))
+    mV = cute.make_tensor(v_cute.iterator, cute.make_layout(
+        (hd, s_k, ((h_r, h_k), batch)),
+        stride=(1, hd * h_k, ((0, hd), hd * h_k * s_k)),
+    ))
+
+    print(f"mQ shape: {cute.shape(mQ)}")
+    print(f"mK shape: {cute.shape(mK)}")
+    print(f"mV shape: {cute.shape(mV)}")
+
+    # Major modes
+    a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
+    b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
+    v_major = LayoutEnum.from_tensor(mV).mma_major_mode()
+    print(f"a_major={a_major}, b_major={b_major}, v_major={v_major}")
+
+    # flat_divide
+    qk_mma_tiler = (128, 128, min(hd, 256))
+    pv_n_tile = min(hd, 256)
+    pv_ik = 8  # typical for hd=64
+    pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
+
+    gQ = cute.flat_divide(mQ, cute.select(qk_mma_tiler, mode=[0, 2]))
+    gK = cute.flat_divide(mK, cute.select(qk_mma_tiler, mode=[1, 2]))
+    gV = cute.flat_divide(mV, cute.select(pv_mma_tiler, mode=[1, 2]))
+
+    print(f"gQ shape: {cute.shape(gQ)}")
+    print(f"gK shape: {cute.shape(gK)}")
+    print(f"gV shape: {cute.shape(gV)}")
+
+    # n_kv_tiles from gK
+    # After flat_divide(mK, (tile_N, tile_K)) the expected mode structure is
+    # (tile_N, tile_K, rest_N, rest_K, ((h_r, h_k), batch)), and n_kv_tiles
+    # is the rest_N extent, i.e. ceil(s_k / tile_N):
+    #   s_k=128, tile_N=128 -> 1;  s_k=256, tile_N=128 -> 2
+    gK_shape = cute.shape(gK)
+    print(f"gK shape details: {gK_shape}")
+    # Ceil, not floor: a partial trailing tile still counts as a tile. The
+    # tile size comes from the tiler instead of a second hard-coded 128.
+    print(f"n_kv_tiles would be: {math.ceil(s_k / qk_mma_tiler[1])}")
+
+    # Also test with the O tensor (allocated once; the same tensor feeds both
+    # from_dlpack and get_leading_dim)
+    o = torch.zeros(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda')
+    o_cute = ct.from_dlpack(o).mark_layout_dynamic(leading_dim=ct.get_leading_dim(o))
+    mO = cute.make_tensor(o_cute.iterator, cute.make_layout(
+        (T, hd, ((h_r, h_k), batch)),
+        stride=(hd * h_r * h_k, 1, ((hd, hd * h_r), hd * h_r * h_k * T)),
+    ))
+
+    epi_tile = (128, min(hd, 256))
+    gO = cute.flat_divide(mO, epi_tile)
+    print(f"gO shape: {cute.shape(gO)}")
+
+
+def test():
+    """Run the shape diagnostic for three (n_h, batch) configurations."""
+    print("=== D2 flat_divide shape diagnostic ===")
+    test_shapes(64, 1, 1, 128, 128)
+    test_shapes(64, 2, 1, 128, 128)
+    test_shapes(64, 4, 2, 128, 128)
+
+
+if __name__ == '__main__':
+    test()