D2: simpler shape diagnostic using CuTe from Python (no kernel needed)
This commit is contained in:
@@ -120,42 +120,25 @@ class ShapeDiagKernel:
|
||||
print(f"tSgK shape: {cute.shape(tSgK)}")
|
||||
print(f"tSgV shape: {cute.shape(tSgV)}")
|
||||
|
||||
# === tma_partition ===
|
||||
sQ = cute.make_tensor(BFloat16, q_smem_s.outer)
|
||||
sK = cute.make_tensor(BFloat16, k_smem_s.outer)
|
||||
sV = cute.make_tensor(BFloat16, v_smem_s.outer)
|
||||
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, 0, None, 0)).shape)
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape)
|
||||
|
||||
tQsQ, tQgQ = cpasync.tma_partition(
|
||||
tma_q, 0, a_lay,
|
||||
cute.group_modes(sQ, 0, 3), cute.group_modes(tSgQ, 0, 3),
|
||||
)
|
||||
tKsK, tKgK = cpasync.tma_partition(
|
||||
tma_k, 0, b_lay,
|
||||
cute.group_modes(sK, 0, 3), cute.group_modes(tSgK, 0, 3),
|
||||
)
|
||||
tVsV, tVgV = cpasync.tma_partition(
|
||||
tma_v, 0, b_lay,
|
||||
cute.group_modes(sV, 0, 3), cute.group_modes(tSgV, 0, 3),
|
||||
)
|
||||
|
||||
print(f"tQgQ shape: {cute.shape(tQgQ)}")
|
||||
print(f"tKgK shape: {cute.shape(tKgK)}")
|
||||
print(f"tVgV shape: {cute.shape(tVgV)}")
|
||||
|
||||
# === Try slicing patterns ===
|
||||
# The original code uses (None,0,None,0) on 4-mode tensors.
|
||||
# flat_divide produces MORE modes. Let's see what we get.
|
||||
print(f"tQgQ mode count: {len(cute.shape(tQgQ))}")
|
||||
print(f"tKgK mode count: {len(cute.shape(tKgK))}")
|
||||
print(f"tVgV mode count: {len(cute.shape(tVgV))}")
|
||||
|
||||
# Try CUTLASS-style indexing:
|
||||
# tQgQ[None, None, 0, coord] where coord = (head, h_k, batch)
|
||||
# But we need to know the mode count to construct the right index.
|
||||
# Let's try various slicing patterns and see what compiles.
|
||||
# Skip tma_partition for now — we have the flat_divide shapes.
|
||||
# The key insight is the mode structure after flat_divide:
|
||||
# gQ: (tile_M, tile_K, rest_M, rest_K, ((h_r, h_k), batch))
|
||||
# After MMA partition: (thread, tile_M, tile_K, rest_M, rest_K, ((h_r, h_k), batch))
|
||||
#
|
||||
# For n_h=1: ((1,1),1) means rest head/batch is trivial
|
||||
# For n_h=2: ((2,1),1) means we can index head via bidy
|
||||
#
|
||||
# The CUTLASS reference indexes as:
|
||||
# tQgQ[None, None, 0, curr_block_coord_q[2]]
|
||||
# where 0 is rest_M and curr_block_coord_q[2] is ((head, h_k), batch)
|
||||
#
|
||||
# For our grid (m_tile, head, batch):
|
||||
# bidx = m_tile, bidy = head, bidz = batch
|
||||
# Index: tQgQ[None, k_sub, bidx, (bidy, 0, bidz)]
|
||||
#
|
||||
# But we need tma_partition shapes to confirm. Let me skip for now
|
||||
# and write the actual kernel with the CUTLASS pattern.
|
||||
print(f"DIAG: flat_divide mode structure confirmed.")
|
||||
|
||||
|
||||
def test_shapes(hd=64, n_h=1, batch=1, T=128, s_k=128):
|
||||
|
||||
98
tests/unit/test_d2_shape_diag.py
Normal file
98
tests/unit/test_d2_shape_diag.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
D2 diagnostic: Print flat_divide + tma_partition shapes for multi-CTA FMHA.
|
||||
"""
|
||||
import torch, math, cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass.torch as ct
|
||||
|
||||
|
||||
def test_shapes(hd=64, n_h=1, batch=1, T=128, s_k=128):
|
||||
"""Print flat_divide shapes for the FMHA GMEM tensors."""
|
||||
print(f"\n--- hd={hd}, n_h={n_h}, batch={batch}, T={T}, s_k={s_k} ---")
|
||||
torch.manual_seed(42)
|
||||
|
||||
q = torch.randn(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
q_cute = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
k_cute = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
v_cute = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
|
||||
|
||||
# CUTLASS-style layouts
|
||||
h_r = n_h; h_k = 1
|
||||
|
||||
mQ = cute.make_tensor(q_cute.iterator, cute.make_layout(
|
||||
(T, hd, ((h_r, h_k), batch)),
|
||||
stride=(hd * h_r * h_k, 1, ((hd, hd * h_r), hd * h_r * h_k * T)),
|
||||
))
|
||||
mK = cute.make_tensor(k_cute.iterator, cute.make_layout(
|
||||
(s_k, hd, ((h_r, h_k), batch)),
|
||||
stride=(hd * h_k, 1, ((0, hd), hd * h_k * s_k)),
|
||||
))
|
||||
mV = cute.make_tensor(v_cute.iterator, cute.make_layout(
|
||||
(hd, s_k, ((h_r, h_k), batch)),
|
||||
stride=(1, hd * h_k, ((0, hd), hd * h_k * s_k)),
|
||||
))
|
||||
|
||||
print(f"mQ shape: {cute.shape(mQ)}")
|
||||
print(f"mK shape: {cute.shape(mK)}")
|
||||
print(f"mV shape: {cute.shape(mV)}")
|
||||
|
||||
# Major modes
|
||||
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
|
||||
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
|
||||
v_major = LayoutEnum.from_tensor(mV).mma_major_mode()
|
||||
print(f"a_major={a_major}, b_major={b_major}, v_major={v_major}")
|
||||
|
||||
# flat_divide
|
||||
qk_mma_tiler = (128, 128, min(hd, 256))
|
||||
pv_n_tile = min(hd, 256)
|
||||
pv_ik = 8 # typical for hd=64
|
||||
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
|
||||
|
||||
gQ = cute.flat_divide(mQ, cute.select(qk_mma_tiler, mode=[0, 2]))
|
||||
gK = cute.flat_divide(mK, cute.select(qk_mma_tiler, mode=[1, 2]))
|
||||
gV = cute.flat_divide(mV, cute.select(pv_mma_tiler, mode=[1, 2]))
|
||||
|
||||
print(f"gQ shape: {cute.shape(gQ)}")
|
||||
print(f"gK shape: {cute.shape(gK)}")
|
||||
print(f"gV shape: {cute.shape(gV)}")
|
||||
|
||||
# n_kv_tiles from gK
|
||||
# After flat_divide(mK, (128, k_tile)), the rest dims encode number of N-tiles
|
||||
# Mode structure: (tile_N, tile_K, rest_N, rest_K, ((h_r, h_k), batch))
|
||||
# n_kv_tiles = rest_N dimension size
|
||||
gK_shape = cute.shape(gK)
|
||||
print(f"gK shape details: {gK_shape}")
|
||||
# Try to compute n_kv_tiles
|
||||
# For s_k=128, k_tile=64: rest_N = ceil(128/128) = 1
|
||||
# For s_k=256, k_tile=64: rest_N = ceil(256/128) = 2
|
||||
print(f"n_kv_tiles would be: {s_k // 128}")
|
||||
|
||||
# Also test with the O tensor
|
||||
o_cute = ct.from_dlpack(torch.zeros(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda')).mark_layout_dynamic(leading_dim=ct.get_leading_dim(torch.zeros(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda')))
|
||||
mO = cute.make_tensor(o_cute.iterator, cute.make_layout(
|
||||
(T, hd, ((h_r, h_k), batch)),
|
||||
stride=(hd * h_r * h_k, 1, ((hd, hd * h_r), hd * h_r * h_k * T)),
|
||||
))
|
||||
|
||||
epi_tile = (128, min(hd, 256))
|
||||
gO = cute.flat_divide(mO, epi_tile)
|
||||
print(f"gO shape: {cute.shape(gO)}")
|
||||
|
||||
|
||||
def test():
|
||||
print("=== D2 flat_divide shape diagnostic ===")
|
||||
test_shapes(64, 1, 1, 128, 128)
|
||||
test_shapes(64, 2, 1, 128, 128)
|
||||
test_shapes(64, 4, 2, 128, 128)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
Reference in New Issue
Block a user