D2: add flat_divide shape diagnostic kernel for multi-CTA grid
This commit is contained in:
188
tests/unit/test_d2_flat_divide_diag.py
Normal file
188
tests/unit/test_d2_flat_divide_diag.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
D2 diagnostic: Print flat_divide + tma_partition shapes for multi-CTA FMHA.
|
||||
|
||||
Runs inside @cute.kernel to get trace-time shapes.
|
||||
This tells us exactly how to index the partitioned GMEM tensors
|
||||
with runtime block_idx coordinates.
|
||||
|
||||
Test with: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d2_flat_divide_diag.py
|
||||
"""
|
||||
import torch, math, cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
import cutlass.utils.sm100 as sm100
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass.torch as ct
|
||||
|
||||
|
||||
class ShapeDiagKernel:
    """Trace-time shape probe for flat_divide / tma_partition on Q, K and V.

    The host-side ``__call__`` builds the global-memory views, the two tiled
    MMAs, the shared-memory layouts and the TMA atoms, then launches
    ``_diag`` on a single CTA. ``_diag`` issues no copies or MMAs; it only
    prints the shapes that each partitioning step produces.
    """

    def __init__(self, hd=64, n_h=1, batch=1, s_k=128):
        """Record the problem extents and the fixed MMA / cluster settings."""
        self.hd = hd
        self.n_h = n_h
        self.batch = batch
        self.s_k = s_k
        self.q_dtype = BFloat16
        self.cta_group = tcgen05.CtaGroup.ONE
        self.cluster_shape_mn = (1, 1)

    @cute.jit
    def __call__(self, q, k, v, stream):
        """Build layouts, MMAs and TMA atoms, then launch the diagnostic kernel.

        From the inputs only ``q.iterator``, ``k.iterator``, ``v.iterator``
        and ``q.shape[2]`` are read; every other extent comes from the values
        given to the constructor.
        """
        hd = self.hd
        h_r = self.n_h
        h_k = 1
        batch = self.batch
        s_k_len = self.s_k
        T = q.shape[2]

        # Strides shared by the three layouts below.
        q_row_stride = hd * h_r * h_k
        kv_row_stride = hd * h_k
        kv_batch_stride = kv_row_stride * s_k_len
        head_batch_shape = ((h_r, h_k), batch)

        # CUTLASS-style GMEM views: (rows, cols, ((h_r, h_k), batch)).
        # K and V use stride 0 along h_r, so that mode does not advance the
        # address.
        q_layout = cute.make_layout(
            (T, hd, head_batch_shape),
            stride=(q_row_stride, 1, ((hd, hd * h_r), q_row_stride * T)),
        )
        k_layout = cute.make_layout(
            (s_k_len, hd, head_batch_shape),
            stride=(kv_row_stride, 1, ((0, hd), kv_batch_stride)),
        )
        v_layout = cute.make_layout(
            (hd, s_k_len, head_batch_shape),
            stride=(1, kv_row_stride, ((0, hd), kv_batch_stride)),
        )
        mQ = cute.make_tensor(q.iterator, q_layout)
        mK = cute.make_tensor(k.iterator, k_layout)
        mV = cute.make_tensor(v.iterator, v_layout)

        a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
        b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
        v_major = LayoutEnum.from_tensor(mV).mma_major_mode()

        # Tile extent derived from hd, capped at 256; used by both MMAs.
        hd_tile = min(hd, 256)

        qk_mma = sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, a_major, b_major,
            Float32, self.cta_group, (128, 128), tcgen05.OperandSource.SMEM,
        )
        pv_mma = sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, v_major,
            Float32, self.cta_group, (128, hd_tile), tcgen05.OperandSource.TMEM,
        )

        qk_mma_tiler = (128, 128, hd_tile)
        pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
        # K extent of the PV tile: 128 rounded down to a multiple of pv_ik.
        pv_mma_tiler = (128, hd_tile, pv_ik * (128 // pv_ik))

        q_smem_s = sm100.make_smem_layout_a(qk_mma, qk_mma_tiler, self.q_dtype, 1)
        k_smem_s = sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, self.q_dtype, 1)
        v_smem_s = sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, self.q_dtype, 1)

        cluster_layout_vmnk = cute.tiled_divide(
            cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,)
        )

        q_tma_op = sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id)
        tma_q, tma_mQ = cute.nvgpu.make_tiled_tma_atom_A(
            q_tma_op,
            mQ, cute.select(q_smem_s, mode=[0, 1, 2]), qk_mma_tiler, qk_mma,
            cluster_layout_vmnk.shape,
        )
        k_tma_op = sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id)
        tma_k, tma_mK = cute.nvgpu.make_tiled_tma_atom_B(
            k_tma_op,
            mK, cute.select(k_smem_s, mode=[0, 1, 2]), qk_mma_tiler, qk_mma,
            cluster_layout_vmnk.shape,
        )
        v_tma_op = sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id)
        tma_v, tma_mV = cute.nvgpu.make_tiled_tma_atom_B(
            v_tma_op,
            mV, cute.select(v_smem_s, mode=[0, 1, 2]), pv_mma_tiler, pv_mma,
            cluster_layout_vmnk.shape,
        )

        # One CTA of 192 threads; the kernel body only runs on warp 5.
        self._diag(
            tma_mQ, tma_mK, tma_mV, tma_q, tma_k, tma_v,
            qk_mma, pv_mma, cluster_layout_vmnk,
            q_smem_s, k_smem_s, v_smem_s, qk_mma_tiler, pv_mma_tiler,
        ).launch(grid=(1, 1, 1), block=[192, 1, 1], stream=stream)

    @cute.kernel
    def _diag(self, mQ, mK, mV, tma_q, tma_k, tma_v,
              qk_mma, pv_mma, cl_vmnk,
              q_smem_s, k_smem_s, v_smem_s, qk_mma_tiler, pv_mma_tiler):
        """Print the shape produced by each partitioning stage for Q, K and V."""
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        if warp_idx == 5:
            qk_slice = qk_mma.get_slice(0)
            pv_slice = pv_mma.get_slice(0)

            # --- full GMEM tensor shapes ---
            print(f"mQ shape: {cute.shape(mQ)}")
            print(f"mK shape: {cute.shape(mK)}")
            print(f"mV shape: {cute.shape(mV)}")

            # --- flat_divide by the per-operand tile ---
            q_tile = cute.select(qk_mma_tiler, mode=[0, 2])
            k_tile = cute.select(qk_mma_tiler, mode=[1, 2])
            v_tile = cute.select(pv_mma_tiler, mode=[1, 2])
            gQ = cute.flat_divide(mQ, q_tile)
            gK = cute.flat_divide(mK, k_tile)
            gV = cute.flat_divide(mV, v_tile)

            print(f"gQ shape: {cute.shape(gQ)}")
            print(f"gK shape: {cute.shape(gK)}")
            print(f"gV shape: {cute.shape(gV)}")

            # --- MMA partition ---
            tSgQ = qk_slice.partition_A(gQ)
            tSgK = qk_slice.partition_B(gK)
            tSgV = pv_slice.partition_B(gV)

            print(f"tSgQ shape: {cute.shape(tSgQ)}")
            print(f"tSgK shape: {cute.shape(tSgK)}")
            print(f"tSgV shape: {cute.shape(tSgV)}")

            # --- tma_partition ---
            sQ = cute.make_tensor(BFloat16, q_smem_s.outer)
            sK = cute.make_tensor(BFloat16, k_smem_s.outer)
            sV = cute.make_tensor(BFloat16, v_smem_s.outer)

            a_cta_layout = cute.make_layout(
                cute.slice_(cl_vmnk, (0, 0, None, 0)).shape
            )
            b_cta_layout = cute.make_layout(
                cute.slice_(cl_vmnk, (0, None, 0, 0)).shape
            )

            tQsQ, tQgQ = cpasync.tma_partition(
                tma_q, 0, a_cta_layout,
                cute.group_modes(sQ, 0, 3), cute.group_modes(tSgQ, 0, 3),
            )
            tKsK, tKgK = cpasync.tma_partition(
                tma_k, 0, b_cta_layout,
                cute.group_modes(sK, 0, 3), cute.group_modes(tSgK, 0, 3),
            )
            tVsV, tVgV = cpasync.tma_partition(
                tma_v, 0, b_cta_layout,
                cute.group_modes(sV, 0, 3), cute.group_modes(tSgV, 0, 3),
            )

            print(f"tQgQ shape: {cute.shape(tQgQ)}")
            print(f"tKgK shape: {cute.shape(tKgK)}")
            print(f"tVgV shape: {cute.shape(tVgV)}")

            # --- mode counts, needed to pick a slicing pattern ---
            # The original code slices 4-mode tensors with (None, 0, None, 0);
            # flat_divide yields more modes, so report how many there are.
            print(f"tQgQ mode count: {len(cute.shape(tQgQ))}")
            print(f"tKgK mode count: {len(cute.shape(tKgK))}")
            print(f"tVgV mode count: {len(cute.shape(tVgV))}")

            # Next step (not implemented here): CUTLASS-style indexing such as
            # tQgQ[None, None, 0, coord] with coord = (head, h_k, batch).
            # The mode counts printed above are required to build that index,
            # after which candidate slicing patterns can be tried to see
            # which ones compile.
|
||||
|
||||
|
||||
def test_shapes(hd=64, n_h=1, batch=1, T=128, s_k=128):
    """Compile and run ShapeDiagKernel once for the given problem sizes.

    Allocates random bfloat16 Q (batch, n_h, T, hd) and K/V (batch, s_k, hd)
    tensors on the CUDA device, wraps them as dynamic-layout CuTe tensors,
    then compiles and launches the diagnostic kernel on the current stream.
    """
    print(f"\n--- hd={hd}, n_h={n_h}, batch={batch}, T={T}, s_k={s_k} ---")

    def _to_cute(tensor):
        # Wrap a torch tensor via DLPack and mark its layout dynamic along
        # the leading dimension reported by cutlass.torch.
        wrapped = ct.from_dlpack(tensor)
        return wrapped.mark_layout_dynamic(leading_dim=ct.get_leading_dim(tensor))

    # Allocation order (q, k, v) is kept so the RNG stream is consumed the
    # same way on every run.
    q = torch.randn(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda')

    q_cute = _to_cute(q)
    k_cute = _to_cute(k)
    v_cute = _to_cute(v)

    kernel = ShapeDiagKernel(hd=hd, n_h=n_h, batch=batch, s_k=s_k)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    print(f'Compiling (hd={hd}, n_h={n_h})...', flush=True)
    launch_args = (q_cute, k_cute, v_cute, stream)
    compiled = cute.compile(kernel, *launch_args)
    compiled(*launch_args)
    torch.cuda.synchronize()
    print('Done.')
|
||||
|
||||
|
||||
def test():
    """Run the shape diagnostic for one and for two query heads."""
    print("=== D2 flat_divide shape diagnostic ===")
    # Same sizes for both runs; only the number of query heads changes.
    for n_h in (1, 2):
        test_shapes(64, n_h, 1, 128, 128)


if __name__ == '__main__':
    test()
|
||||
90
tests/unit/test_d2_multicta.py
Normal file
90
tests/unit/test_d2_multicta.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
FMHA D2: Multi-CTA grid with flat_divide + tma_partition inside kernel.
|
||||
|
||||
Proper Blackwell approach following CUTLASS reference pattern:
|
||||
- Q/K/V/O tensors with embedded head dimensions: (s, d, ((h_r, h_k), batch))
|
||||
- flat_divide + tma_partition inside warp blocks (runtime block_idx)
|
||||
- Direct TMA bulk copy for O output (not epilogue_tma_store)
|
||||
- Grid: (M_tiles, h_q, batch) — one CTA per (M-tile, head, batch) triple
|
||||
|
||||
DSV4 is MQA: h_r = num_query_heads, h_k = 1, K/V shared.
|
||||
"""
|
||||
import torch, math
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
|
||||
def reference_attention(q, k, v, scale):
    """FP32 reference attention with K/V shared across all query heads.

    Args:
        q: Query tensor of shape (batch, n_h, T, hd).
        k: Key tensor of shape (batch, s_k, hd).
        v: Value tensor of shape (batch, s_k, hd).
        scale: Multiplier applied to the Q·K^T scores before the softmax.

    Returns:
        A float32 tensor of shape (batch, n_h, T, hd) on the same device as
        ``q``, holding ``softmax(q @ k^T * scale) @ v`` for every
        (batch, head) pair.
    """
    qf = q.float()
    # Insert a singleton head axis so K and V broadcast over the n_h heads.
    kf = k.float().unsqueeze(1)  # (batch, 1, s_k, hd)
    vf = v.float().unsqueeze(1)  # (batch, 1, s_k, hd)

    scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale  # (batch, n_h, T, s_k)
    # torch.softmax subtracts the row maximum internally, so it is as stable
    # as an explicit max-subtract / exp / normalise sequence.
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, vf)
|
||||
|
||||
|
||||
def test_multicta(hd=64, n_h=1, batch=1, T=128, s_k=128):
    """Run FmhaKernel for one configuration and compare with the FP32 reference.

    Q has n_h heads while K and V carry no head axis (MQA). The output is
    compared with ``reference_attention`` by cosine similarity over the
    flattened tensors.

    Returns:
        True when the cosine similarity is at least 0.99, otherwise False.
    """
    torch.manual_seed(42)
    scale = 1.0 / math.sqrt(hd)

    bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
    # Inputs are drawn in q, k, v order so the seeded RNG yields the same data.
    q = torch.randn(batch, n_h, T, hd, **bf16_cuda)
    k = torch.randn(batch, s_k, hd, **bf16_cuda)
    v = torch.randn(batch, s_k, hd, **bf16_cuda)
    # Output buffers passed to the kernel.
    o = torch.zeros(batch, n_h, T, hd, **bf16_cuda)
    lse = torch.zeros(batch, n_h, T, dtype=torch.float32, device='cuda')

    # FP32 reference
    ref = reference_attention(q, k, v, scale)

    # Run kernel with multi-CTA grid
    kernel = FmhaKernel(
        head_dim=hd,
        s_k=s_k,
        num_query_heads=n_h,
        batch_size=batch,
        use_smem_p=(hd > 64),
        normalize=True,
    )
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel(q, k, v, o, stream, lse=lse)
    torch.cuda.synchronize()

    out_flat = o.flatten().float().unsqueeze(0)
    ref_flat = ref.flatten().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(out_flat, ref_flat).item()

    passed = cos >= 0.99
    if passed:
        status = "PASS"
    else:
        status = "FAIL"
    print(f' hd={hd}, n_h={n_h}, batch={batch}, T={T}, s_k={s_k}: cos {cos:.6f} {status}')
    return passed
|
||||
|
||||
|
||||
def test():
    """Run every multi-CTA FMHA configuration and fail if any one fails.

    All configurations are executed and reported before the aggregate result
    is checked, so one failure does not hide the status of the others.

    Raises:
        AssertionError: If at least one configuration reported FAIL. Without
            this, a test runner or the script's exit code would report
            success even though "SOME FAILED" was printed.
    """
    print("=== D2: Multi-CTA Grid (flat_divide approach) ===\n")

    # Each entry is (hd, n_h, batch, T, s_k).
    configs = [
        (64, 1, 1, 128, 128),  # n_h=1 regression (should match existing single-head behavior)
        (64, 2, 1, 128, 128),  # n_h=2, single batch
        (64, 4, 2, 128, 128),  # n_h=4, batch=2
        (64, 8, 1, 128, 128),  # n_h=8
    ]

    # Build a list rather than a generator so all() cannot short-circuit:
    # every configuration must run and print its own PASS/FAIL line.
    results = [test_multicta(*cfg) for cfg in configs]
    all_pass = all(results)

    print(f'\nOverall: {"ALL PASS" if all_pass else "SOME FAILED"}')
    assert all_pass, "one or more multi-CTA FMHA configurations failed"


if __name__ == '__main__':
    test()
|
||||
Reference in New Issue
Block a user