diff --git a/tests/unit/test_d2_flat_divide_diag.py b/tests/unit/test_d2_flat_divide_diag.py new file mode 100644 index 00000000..61a3e55c --- /dev/null +++ b/tests/unit/test_d2_flat_divide_diag.py @@ -0,0 +1,188 @@ +""" +D2 diagnostic: Print flat_divide + tma_partition shapes for multi-CTA FMHA. + +Runs inside @cute.kernel to get trace-time shapes. +This tells us exactly how to index the partitioned GMEM tensors +with runtime block_idx coordinates. + +Test with: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d2_flat_divide_diag.py +""" +import torch, math, cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.utils.sm100 as sm100 +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, const_expr +from cutlass.utils import LayoutEnum +import cuda.bindings.driver as cuda +import cutlass.torch as ct + + +class ShapeDiagKernel: + def __init__(self, hd=64, n_h=1, batch=1, s_k=128): + self.hd = hd; self.n_h = n_h; self.batch = batch; self.s_k = s_k + self.q_dtype = BFloat16 + self.cta_group = tcgen05.CtaGroup.ONE + self.cluster_shape_mn = (1, 1) + + @cute.jit + def __call__(self, q, k, v, stream): + n_h = self.n_h; h_r = n_h; h_k = 1; hd = self.hd + batch = self.batch; s_k_len = self.s_k; T = q.shape[2] + + # CUTLASS-style GMEM tensor layouts + mQ = cute.make_tensor(q.iterator, cute.make_layout( + (T, hd, ((h_r, h_k), batch)), + stride=(hd * h_r * h_k, 1, ((hd, hd * h_r), hd * h_r * h_k * T)), + )) + mK = cute.make_tensor(k.iterator, cute.make_layout( + (s_k_len, hd, ((h_r, h_k), batch)), + stride=(hd * h_k, 1, ((0, hd), hd * h_k * s_k_len)), + )) + mV = cute.make_tensor(v.iterator, cute.make_layout( + (hd, s_k_len, ((h_r, h_k), batch)), + stride=(1, hd * h_k, ((0, hd), hd * h_k * s_k_len)), + )) + + a_major = LayoutEnum.from_tensor(mQ).mma_major_mode() + b_major = LayoutEnum.from_tensor(mK).mma_major_mode() + v_major = LayoutEnum.from_tensor(mV).mma_major_mode() + + qk_mma = sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, a_major, b_major, + Float32, self.cta_group, (128, 128), tcgen05.OperandSource.SMEM, + ) + pv_mma = sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, v_major, + Float32, self.cta_group, (128, min(hd, 256)), tcgen05.OperandSource.TMEM, + ) + + qk_mma_tiler = (128, 128, min(hd, 256)) + pv_n_tile = min(hd, 256) + pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) + pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik)) + + q_smem_s = sm100.make_smem_layout_a(qk_mma, qk_mma_tiler, self.q_dtype, 1) + k_smem_s = sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, self.q_dtype, 1) + v_smem_s = sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, self.q_dtype, 1) + + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,) + ) + + tma_q, tma_mQ = cute.nvgpu.make_tiled_tma_atom_A( + sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + mQ, cute.select(q_smem_s, mode=[0, 1, 2]), qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape, + ) + tma_k, tma_mK = cute.nvgpu.make_tiled_tma_atom_B( + sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + mK, cute.select(k_smem_s, mode=[0, 1, 2]), qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape, + ) + tma_v, tma_mV = cute.nvgpu.make_tiled_tma_atom_B( + sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), + mV, cute.select(v_smem_s, mode=[0, 1, 2]), pv_mma_tiler, pv_mma, cluster_layout_vmnk.shape, + ) + + self._diag(tma_mQ, tma_mK, tma_mV, tma_q, tma_k, tma_v, + qk_mma, pv_mma, cluster_layout_vmnk, + q_smem_s, k_smem_s, v_smem_s, qk_mma_tiler, pv_mma_tiler, + ).launch(grid=(1, 1, 1), block=[192, 1, 1], stream=stream) + + @cute.kernel + def _diag(self, mQ, mK, mV, tma_q, tma_k, tma_v, + qk_mma, pv_mma, cl_vmnk, + q_smem_s, k_smem_s, v_smem_s, qk_mma_tiler, pv_mma_tiler): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 5: + qk_thr = qk_mma.get_slice(0) + pv_thr = pv_mma.get_slice(0) + + # === Print full GMEM tensor shapes === + print(f"mQ shape: {cute.shape(mQ)}") + print(f"mK shape: {cute.shape(mK)}") + print(f"mV shape: {cute.shape(mV)}") + + # === flat_divide === + gQ = cute.flat_divide(mQ, cute.select(qk_mma_tiler, mode=[0, 2])) + gK = cute.flat_divide(mK, cute.select(qk_mma_tiler, mode=[1, 2])) + gV = cute.flat_divide(mV, cute.select(pv_mma_tiler, mode=[1, 2])) + + print(f"gQ shape: {cute.shape(gQ)}") + print(f"gK shape: {cute.shape(gK)}") + print(f"gV shape: {cute.shape(gV)}") + + # === MMA partition === + tSgQ = qk_thr.partition_A(gQ) + tSgK = qk_thr.partition_B(gK) + tSgV = pv_thr.partition_B(gV) + + print(f"tSgQ shape: {cute.shape(tSgQ)}") + print(f"tSgK shape: {cute.shape(tSgK)}") + print(f"tSgV shape: {cute.shape(tSgV)}") + + # === tma_partition === + sQ = cute.make_tensor(BFloat16, q_smem_s.outer) + sK = cute.make_tensor(BFloat16, k_smem_s.outer) + sV = cute.make_tensor(BFloat16, v_smem_s.outer) + + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, 0, None, 0)).shape) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape) + + tQsQ, tQgQ = cpasync.tma_partition( + tma_q, 0, a_lay, + cute.group_modes(sQ, 0, 3), cute.group_modes(tSgQ, 0, 3), + ) + tKsK, tKgK = cpasync.tma_partition( + tma_k, 0, b_lay, + cute.group_modes(sK, 0, 3), cute.group_modes(tSgK, 0, 3), + ) + tVsV, tVgV = cpasync.tma_partition( + tma_v, 0, b_lay, + cute.group_modes(sV, 0, 3), cute.group_modes(tSgV, 0, 3), + ) + + print(f"tQgQ shape: {cute.shape(tQgQ)}") + print(f"tKgK shape: {cute.shape(tKgK)}") + print(f"tVgV shape: {cute.shape(tVgV)}") + + # === Try slicing patterns === + # The original code uses (None,0,None,0) on 4-mode tensors. + # flat_divide produces MORE modes. Let's see what we get. + print(f"tQgQ mode count: {len(cute.shape(tQgQ))}") + print(f"tKgK mode count: {len(cute.shape(tKgK))}") + print(f"tVgV mode count: {len(cute.shape(tVgV))}") + + # Try CUTLASS-style indexing: + # tQgQ[None, None, 0, coord] where coord = (head, h_k, batch) + # But we need to know the mode count to construct the right index. + # Let's try various slicing patterns and see what compiles. + + +def test_shapes(hd=64, n_h=1, batch=1, T=128, s_k=128): + print(f"\n--- hd={hd}, n_h={n_h}, batch={batch}, T={T}, s_k={s_k} ---") + + q = torch.randn(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda') + k = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda') + v = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda') + + q_cute = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + k_cute = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + v_cute = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v)) + + kernel = ShapeDiagKernel(hd=hd, n_h=n_h, batch=batch, s_k=s_k) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + print(f'Compiling (hd={hd}, n_h={n_h})...', flush=True) + compiled = cute.compile(kernel, q_cute, k_cute, v_cute, stream) + compiled(q_cute, k_cute, v_cute, stream) + torch.cuda.synchronize() + print('Done.') + + +def test(): + print("=== D2 flat_divide shape diagnostic ===") + test_shapes(64, 1, 1, 128, 128) + test_shapes(64, 2, 1, 128, 128) + + +if __name__ == '__main__': + test() diff --git a/tests/unit/test_d2_multicta.py b/tests/unit/test_d2_multicta.py new file mode 100644 index 00000000..ae31dff2 --- /dev/null +++ b/tests/unit/test_d2_multicta.py @@ -0,0 +1,90 @@ +""" +FMHA D2: Multi-CTA grid with flat_divide + tma_partition inside kernel. + +Proper Blackwell approach following CUTLASS reference pattern: +- Q/K/V/O tensors with embedded head dimensions: (s, d, ((h_r, h_k), batch)) +- flat_divide + tma_partition inside warp blocks (runtime block_idx) +- Direct TMA bulk copy for O output (not epilogue_tma_store) +- Grid: (M_tiles, h_q, batch) — one CTA per (M-tile, head, batch) triple + +DSV4 is MQA: h_r = num_query_heads, h_k = 1, K/V shared. +""" +import torch, math +import cutlass.cute as cute +import cutlass.torch as ct +import cuda.bindings.driver as cuda +from dsv4.kernels.attention.fmha import FmhaKernel + + +def reference_attention(q, k, v, scale): + """FP32 reference attention. q: (batch, n_h, T, hd), k/v: (batch, s_k, hd)""" + qf = q.float() + kf = k.float() + vf = v.float() + batch, n_h, T, hd = q.shape + s_k = k.shape[1] + ref = torch.zeros_like(qf) + for b in range(batch): + for h in range(n_h): + attn = qf[b, h] @ kf[b].T * scale + attn_max = attn.max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(attn - attn_max) + attn_sum = attn_exp.sum(dim=-1, keepdim=True) + ref[b, h] = (attn_exp / attn_sum) @ vf[b] + return ref + + +def test_multicta(hd=64, n_h=1, batch=1, T=128, s_k=128): + """Test multi-CTA grid FMHA with n_h query heads (MQA).""" + torch.manual_seed(42) + scale = 1.0 / math.sqrt(hd) + + q = torch.randn(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda') + k = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda') + v = torch.randn(batch, s_k, hd, dtype=torch.bfloat16, device='cuda') + o = torch.zeros(batch, n_h, T, hd, dtype=torch.bfloat16, device='cuda') + lse = torch.zeros(batch, n_h, T, dtype=torch.float32, device='cuda') + + # FP32 reference + ref = reference_attention(q, k, v, scale) + + # Run kernel with multi-CTA grid + kernel = FmhaKernel( + head_dim=hd, s_k=s_k, num_query_heads=n_h, batch_size=batch, + use_smem_p=(hd > 64), normalize=True, + ) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel(q, k, v, o, stream, lse=lse) + torch.cuda.synchronize() + + cos = torch.nn.functional.cosine_similarity( + o.flatten().float().unsqueeze(0), ref.flatten().unsqueeze(0) + ).item() + + status = "PASS" if cos >= 0.99 else "FAIL" + print(f' hd={hd}, n_h={n_h}, batch={batch}, T={T}, s_k={s_k}: cos {cos:.6f} {status}') + return cos >= 0.99 + + +def test(): + print("=== D2: Multi-CTA Grid (flat_divide approach) ===\n") + + all_pass = True + + # n_h=1 regression (should match existing single-head behavior) + all_pass &= test_multicta(64, 1, 1, 128, 128) + + # n_h=2, single batch + all_pass &= test_multicta(64, 2, 1, 128, 128) + + # n_h=4, batch=2 + all_pass &= test_multicta(64, 4, 2, 128, 128) + + # n_h=8 + all_pass &= test_multicta(64, 8, 1, 128, 128) + + print(f'\nOverall: {"ALL PASS" if all_pass else "SOME FAILED"}') + + +if __name__ == '__main__': + test()