D2: simpler shape diagnostic using CuTe from Python (no kernel needed)

This commit is contained in:
2026-05-25 02:36:41 +00:00
parent 684e9a85fe
commit d5b69ac122
2 changed files with 117 additions and 36 deletions

View File

@@ -120,42 +120,25 @@ class ShapeDiagKernel:
print(f"tSgK shape: {cute.shape(tSgK)}")
print(f"tSgV shape: {cute.shape(tSgV)}")
# === tma_partition ===
sQ = cute.make_tensor(BFloat16, q_smem_s.outer)
sK = cute.make_tensor(BFloat16, k_smem_s.outer)
sV = cute.make_tensor(BFloat16, v_smem_s.outer)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, 0, None, 0)).shape)
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape)
tQsQ, tQgQ = cpasync.tma_partition(
tma_q, 0, a_lay,
cute.group_modes(sQ, 0, 3), cute.group_modes(tSgQ, 0, 3),
)
tKsK, tKgK = cpasync.tma_partition(
tma_k, 0, b_lay,
cute.group_modes(sK, 0, 3), cute.group_modes(tSgK, 0, 3),
)
tVsV, tVgV = cpasync.tma_partition(
tma_v, 0, b_lay,
cute.group_modes(sV, 0, 3), cute.group_modes(tSgV, 0, 3),
)
print(f"tQgQ shape: {cute.shape(tQgQ)}")
print(f"tKgK shape: {cute.shape(tKgK)}")
print(f"tVgV shape: {cute.shape(tVgV)}")
# === Try slicing patterns ===
# The original code uses (None,0,None,0) on 4-mode tensors.
# flat_divide produces MORE modes. Let's see what we get.
print(f"tQgQ mode count: {len(cute.shape(tQgQ))}")
print(f"tKgK mode count: {len(cute.shape(tKgK))}")
print(f"tVgV mode count: {len(cute.shape(tVgV))}")
# Try CUTLASS-style indexing:
# tQgQ[None, None, 0, coord] where coord = (head, h_k, batch)
# But we need to know the mode count to construct the right index.
# Let's try various slicing patterns and see what compiles.
# Skip tma_partition for now — we have the flat_divide shapes.
# The key insight is the mode structure after flat_divide:
# gQ: (tile_M, tile_K, rest_M, rest_K, ((h_r, h_k), batch))
# After MMA partition: (thread, tile_M, tile_K, rest_M, rest_K, ((h_r, h_k), batch))
#
# For n_h=1: ((1,1),1) means rest head/batch is trivial
# For n_h=2: ((2,1),1) means we can index head via bidy
#
# The CUTLASS reference indexes as:
# tQgQ[None, None, 0, curr_block_coord_q[2]]
# where 0 is rest_M and curr_block_coord_q[2] is ((head, h_k), batch)
#
# For our grid (m_tile, head, batch):
# bidx = m_tile, bidy = head, bidz = batch
# Index: tQgQ[None, k_sub, bidx, (bidy, 0, bidz)]
#
# But we need tma_partition shapes to confirm. Let me skip for now
# and write the actual kernel with the CUTLASS pattern.
print(f"DIAG: flat_divide mode structure confirmed.")
def test_shapes(hd=64, n_h=1, batch=1, T=128, s_k=128):

View File

@@ -0,0 +1,98 @@
"""
D2 diagnostic: Print flat_divide + tma_partition shapes for multi-CTA FMHA.
"""
import torch, math, cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
import cutlass.torch as ct
def test_shapes(hd=64, n_h=1, batch=1, T=128, s_k=128):
"""Print flat_divide shapes for the FMHA GMEM tensors."""
print(f"\n--- hd={hd}, n_h={n_h}, batch={batch}, T={T}, s_k={s_k} ---")
torch.manual_seed(42)
q = torch.randn(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda')
k = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda')
v = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda')
q_cute = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
k_cute = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
v_cute = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
# CUTLASS-style layouts
h_r = n_h; h_k = 1
mQ = cute.make_tensor(q_cute.iterator, cute.make_layout(
(T, hd, ((h_r, h_k), batch)),
stride=(hd * h_r * h_k, 1, ((hd, hd * h_r), hd * h_r * h_k * T)),
))
mK = cute.make_tensor(k_cute.iterator, cute.make_layout(
(s_k, hd, ((h_r, h_k), batch)),
stride=(hd * h_k, 1, ((0, hd), hd * h_k * s_k)),
))
mV = cute.make_tensor(v_cute.iterator, cute.make_layout(
(hd, s_k, ((h_r, h_k), batch)),
stride=(1, hd * h_k, ((0, hd), hd * h_k * s_k)),
))
print(f"mQ shape: {cute.shape(mQ)}")
print(f"mK shape: {cute.shape(mK)}")
print(f"mV shape: {cute.shape(mV)}")
# Major modes
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
v_major = LayoutEnum.from_tensor(mV).mma_major_mode()
print(f"a_major={a_major}, b_major={b_major}, v_major={v_major}")
# flat_divide
qk_mma_tiler = (128, 128, min(hd, 256))
pv_n_tile = min(hd, 256)
pv_ik = 8 # typical for hd=64
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
gQ = cute.flat_divide(mQ, cute.select(qk_mma_tiler, mode=[0, 2]))
gK = cute.flat_divide(mK, cute.select(qk_mma_tiler, mode=[1, 2]))
gV = cute.flat_divide(mV, cute.select(pv_mma_tiler, mode=[1, 2]))
print(f"gQ shape: {cute.shape(gQ)}")
print(f"gK shape: {cute.shape(gK)}")
print(f"gV shape: {cute.shape(gV)}")
# n_kv_tiles from gK
# After flat_divide(mK, (128, k_tile)), the rest dims encode number of N-tiles
# Mode structure: (tile_N, tile_K, rest_N, rest_K, ((h_r, h_k), batch))
# n_kv_tiles = rest_N dimension size
gK_shape = cute.shape(gK)
print(f"gK shape details: {gK_shape}")
# Try to compute n_kv_tiles
# For s_k=128, k_tile=64: rest_N = ceil(128/128) = 1
# For s_k=256, k_tile=64: rest_N = ceil(256/128) = 2
print(f"n_kv_tiles would be: {s_k // 128}")
# Also test with the O tensor
o_cute = ct.from_dlpack(torch.zeros(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda')).mark_layout_dynamic(leading_dim=ct.get_leading_dim(torch.zeros(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda')))
mO = cute.make_tensor(o_cute.iterator, cute.make_layout(
(T, hd, ((h_r, h_k), batch)),
stride=(hd * h_r * h_k, 1, ((hd, hd * h_r), hd * h_r * h_k * T)),
))
epi_tile = (128, min(hd, 256))
gO = cute.flat_divide(mO, epi_tile)
print(f"gO shape: {cute.shape(gO)}")
def test():
print("=== D2 flat_divide shape diagnostic ===")
test_shapes(64, 1, 1, 128, 128)
test_shapes(64, 2, 1, 128, 128)
test_shapes(64, 4, 2, 128, 128)
if __name__ == '__main__':
test()