D2: add flat_divide shape diagnostic kernel for multi-CTA grid

This commit is contained in:
2026-05-25 02:33:15 +00:00
parent 32850f6974
commit 7599801f57
2 changed files with 278 additions and 0 deletions

View File

@@ -0,0 +1,188 @@
"""
D2 diagnostic: Print flat_divide + tma_partition shapes for multi-CTA FMHA.
Runs inside @cute.kernel to get trace-time shapes.
This tells us exactly how to index the partitioned GMEM tensors
with runtime block_idx coordinates.
Test with: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d2_flat_divide_diag.py
"""
import torch, math, cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.utils.sm100 as sm100
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
import cutlass.torch as ct
class ShapeDiagKernel:
def __init__(self, hd=64, n_h=1, batch=1, s_k=128):
self.hd = hd; self.n_h = n_h; self.batch = batch; self.s_k = s_k
self.q_dtype = BFloat16
self.cta_group = tcgen05.CtaGroup.ONE
self.cluster_shape_mn = (1, 1)
@cute.jit
def __call__(self, q, k, v, stream):
n_h = self.n_h; h_r = n_h; h_k = 1; hd = self.hd
batch = self.batch; s_k_len = self.s_k; T = q.shape[2]
# CUTLASS-style GMEM tensor layouts
mQ = cute.make_tensor(q.iterator, cute.make_layout(
(T, hd, ((h_r, h_k), batch)),
stride=(hd * h_r * h_k, 1, ((hd, hd * h_r), hd * h_r * h_k * T)),
))
mK = cute.make_tensor(k.iterator, cute.make_layout(
(s_k_len, hd, ((h_r, h_k), batch)),
stride=(hd * h_k, 1, ((0, hd), hd * h_k * s_k_len)),
))
mV = cute.make_tensor(v.iterator, cute.make_layout(
(hd, s_k_len, ((h_r, h_k), batch)),
stride=(1, hd * h_k, ((0, hd), hd * h_k * s_k_len)),
))
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
v_major = LayoutEnum.from_tensor(mV).mma_major_mode()
qk_mma = sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype, a_major, b_major,
Float32, self.cta_group, (128, 128), tcgen05.OperandSource.SMEM,
)
pv_mma = sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, v_major,
Float32, self.cta_group, (128, min(hd, 256)), tcgen05.OperandSource.TMEM,
)
qk_mma_tiler = (128, 128, min(hd, 256))
pv_n_tile = min(hd, 256)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
q_smem_s = sm100.make_smem_layout_a(qk_mma, qk_mma_tiler, self.q_dtype, 1)
k_smem_s = sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, self.q_dtype, 1)
v_smem_s = sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, self.q_dtype, 1)
cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,)
)
tma_q, tma_mQ = cute.nvgpu.make_tiled_tma_atom_A(
sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
mQ, cute.select(q_smem_s, mode=[0, 1, 2]), qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape,
)
tma_k, tma_mK = cute.nvgpu.make_tiled_tma_atom_B(
sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
mK, cute.select(k_smem_s, mode=[0, 1, 2]), qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape,
)
tma_v, tma_mV = cute.nvgpu.make_tiled_tma_atom_B(
sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
mV, cute.select(v_smem_s, mode=[0, 1, 2]), pv_mma_tiler, pv_mma, cluster_layout_vmnk.shape,
)
self._diag(tma_mQ, tma_mK, tma_mV, tma_q, tma_k, tma_v,
qk_mma, pv_mma, cluster_layout_vmnk,
q_smem_s, k_smem_s, v_smem_s, qk_mma_tiler, pv_mma_tiler,
).launch(grid=(1, 1, 1), block=[192, 1, 1], stream=stream)
@cute.kernel
def _diag(self, mQ, mK, mV, tma_q, tma_k, tma_v,
qk_mma, pv_mma, cl_vmnk,
q_smem_s, k_smem_s, v_smem_s, qk_mma_tiler, pv_mma_tiler):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
if warp_idx == 5:
qk_thr = qk_mma.get_slice(0)
pv_thr = pv_mma.get_slice(0)
# === Print full GMEM tensor shapes ===
print(f"mQ shape: {cute.shape(mQ)}")
print(f"mK shape: {cute.shape(mK)}")
print(f"mV shape: {cute.shape(mV)}")
# === flat_divide ===
gQ = cute.flat_divide(mQ, cute.select(qk_mma_tiler, mode=[0, 2]))
gK = cute.flat_divide(mK, cute.select(qk_mma_tiler, mode=[1, 2]))
gV = cute.flat_divide(mV, cute.select(pv_mma_tiler, mode=[1, 2]))
print(f"gQ shape: {cute.shape(gQ)}")
print(f"gK shape: {cute.shape(gK)}")
print(f"gV shape: {cute.shape(gV)}")
# === MMA partition ===
tSgQ = qk_thr.partition_A(gQ)
tSgK = qk_thr.partition_B(gK)
tSgV = pv_thr.partition_B(gV)
print(f"tSgQ shape: {cute.shape(tSgQ)}")
print(f"tSgK shape: {cute.shape(tSgK)}")
print(f"tSgV shape: {cute.shape(tSgV)}")
# === tma_partition ===
sQ = cute.make_tensor(BFloat16, q_smem_s.outer)
sK = cute.make_tensor(BFloat16, k_smem_s.outer)
sV = cute.make_tensor(BFloat16, v_smem_s.outer)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, 0, None, 0)).shape)
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape)
tQsQ, tQgQ = cpasync.tma_partition(
tma_q, 0, a_lay,
cute.group_modes(sQ, 0, 3), cute.group_modes(tSgQ, 0, 3),
)
tKsK, tKgK = cpasync.tma_partition(
tma_k, 0, b_lay,
cute.group_modes(sK, 0, 3), cute.group_modes(tSgK, 0, 3),
)
tVsV, tVgV = cpasync.tma_partition(
tma_v, 0, b_lay,
cute.group_modes(sV, 0, 3), cute.group_modes(tSgV, 0, 3),
)
print(f"tQgQ shape: {cute.shape(tQgQ)}")
print(f"tKgK shape: {cute.shape(tKgK)}")
print(f"tVgV shape: {cute.shape(tVgV)}")
# === Try slicing patterns ===
# The original code uses (None,0,None,0) on 4-mode tensors.
# flat_divide produces MORE modes. Let's see what we get.
print(f"tQgQ mode count: {len(cute.shape(tQgQ))}")
print(f"tKgK mode count: {len(cute.shape(tKgK))}")
print(f"tVgV mode count: {len(cute.shape(tVgV))}")
# Try CUTLASS-style indexing:
# tQgQ[None, None, 0, coord] where coord = (head, h_k, batch)
# But we need to know the mode count to construct the right index.
# Let's try various slicing patterns and see what compiles.
def test_shapes(hd=64, n_h=1, batch=1, T=128, s_k=128):
print(f"\n--- hd={hd}, n_h={n_h}, batch={batch}, T={T}, s_k={s_k} ---")
q = torch.randn(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda')
k = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda')
v = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda')
q_cute = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
k_cute = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
v_cute = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
kernel = ShapeDiagKernel(hd=hd, n_h=n_h, batch=batch, s_k=s_k)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
print(f'Compiling (hd={hd}, n_h={n_h})...', flush=True)
compiled = cute.compile(kernel, q_cute, k_cute, v_cute, stream)
compiled(q_cute, k_cute, v_cute, stream)
torch.cuda.synchronize()
print('Done.')
def test():
print("=== D2 flat_divide shape diagnostic ===")
test_shapes(64, 1, 1, 128, 128)
test_shapes(64, 2, 1, 128, 128)
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,90 @@
"""
FMHA D2: Multi-CTA grid with flat_divide + tma_partition inside kernel.
Proper Blackwell approach following CUTLASS reference pattern:
- Q/K/V/O tensors with embedded head dimensions: (s, d, ((h_r, h_k), batch))
- flat_divide + tma_partition inside warp blocks (runtime block_idx)
- Direct TMA bulk copy for O output (not epilogue_tma_store)
- Grid: (M_tiles, h_q, batch) — one CTA per (M-tile, head, batch) triple
DSV4 is MQA: h_r = num_query_heads, h_k = 1, K/V shared.
"""
import torch, math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def reference_attention(q, k, v, scale):
"""FP32 reference attention. q: (batch, n_h, T, hd), k/v: (batch, s_k, hd)"""
qf = q.float()
kf = k.float()
vf = v.float()
batch, n_h, T, hd = q.shape
s_k = k.shape[1]
ref = torch.zeros_like(qf)
for b in range(batch):
for h in range(n_h):
attn = qf[b, h] @ kf[b].T * scale
attn_max = attn.max(dim=-1, keepdim=True)[0]
attn_exp = torch.exp(attn - attn_max)
attn_sum = attn_exp.sum(dim=-1, keepdim=True)
ref[b, h] = (attn_exp / attn_sum) @ vf[b]
return ref
def test_multicta(hd=64, n_h=1, batch=1, T=128, s_k=128):
"""Test multi-CTA grid FMHA with n_h query heads (MQA)."""
torch.manual_seed(42)
scale = 1.0 / math.sqrt(hd)
q = torch.randn(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda')
k = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda')
v = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda')
o = torch.zeros(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda')
lse = torch.zeros(batch, n_h, T, dtype=torch.float32, device='cuda')
# FP32 reference
ref = reference_attention(q, k, v, scale)
# Run kernel with multi-CTA grid
kernel = FmhaKernel(
head_dim=hd, s_k=s_k, num_query_heads=n_h, batch_size=batch,
use_smem_p=(hd > 64), normalize=True,
)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel(q, k, v, o, stream, lse=lse)
torch.cuda.synchronize()
cos = torch.nn.functional.cosine_similarity(
o.flatten().float().unsqueeze(0), ref.flatten().unsqueeze(0)
).item()
status = "PASS" if cos >= 0.99 else "FAIL"
print(f' hd={hd}, n_h={n_h}, batch={batch}, T={T}, s_k={s_k}: cos {cos:.6f} {status}')
return cos >= 0.99
def test():
print("=== D2: Multi-CTA Grid (flat_divide approach) ===\n")
all_pass = True
# n_h=1 regression (should match existing single-head behavior)
all_pass &= test_multicta(64, 1, 1, 128, 128)
# n_h=2, single batch
all_pass &= test_multicta(64, 2, 1, 128, 128)
# n_h=4, batch=2
all_pass &= test_multicta(64, 4, 2, 128, 128)
# n_h=8
all_pass &= test_multicta(64, 8, 1, 128, 128)
print(f'\nOverall: {"ALL PASS" if all_pass else "SOME FAILED"}')
if __name__ == '__main__':
test()