""" P3 Integration Test: 6-warp multi-head decode fast path. Verifies the kernel produces identical results to a PyTorch reference for MHA, MQA, and GQA at HD = 64, 128, 256. Gate: worst-case cosine >= 0.999990 per configuration. """ import torch import math import sys import os sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw def cosine_sim(a, b): a = a.flatten().float() b = b.flatten().float() return (a @ b) / (a.norm() * b.norm() + 1e-30) def reference_attention(q_4d, k_4d, v_4d, scale): """PyTorch reference matching kernel tensor layout. Q: (1, n_h, 1, hd), K: (1, n_kv, N, hd), V: (1, n_kv, hd, N) V is in kernel layout (hd, N) — transpose to (N, hd) for reference. For MQA/GQA, each Q head uses its corresponding KV head. """ n_h = q_4d.shape[1] n_kv = k_4d.shape[1] N = k_4d.shape[2] q_per_kv = n_h // n_kv q = q_4d[0] # (n_h, 1, hd) k = k_4d[0] # (n_kv, N, hd) v = v_4d[0].transpose(-1, -2) # (n_kv, N, hd) output = torch.zeros(n_h, 1, q_4d.shape[3], dtype=torch.bfloat16, device='cuda') for h in range(n_h): kv_idx = h // q_per_kv q_h = q[h] # (1, hd) k_h = k[kv_idx] # (N, hd) v_h = v[kv_idx] # (N, hd) s = torch.matmul(q_h.float(), k_h.float().T) * scale s = torch.softmax(s, dim=-1) o = torch.matmul(s, v_h.float()) output[h] = o.bfloat16() return output def test_kernel_correctness(): """Test kernel vs PyTorch reference for MHA, MQA, GQA at various HD.""" torch.manual_seed(42) configs = [ # (n_q, n_kv, N, hd, desc) (4, 4, 64, 64, "MHA hd=64"), (4, 4, 128, 64, "MHA hd=64 N=128"), (4, 4, 64, 128, "MHA hd=128"), (4, 4, 64, 256, "MHA hd=256"), (4, 1, 64, 64, "MQA hd=64"), (4, 1, 128, 64, "MQA hd=64 N=128"), (4, 1, 64, 128, "MQA hd=128"), (4, 1, 64, 256, "MQA hd=256"), (128, 1, 64, 64, "MQA Pro hd=64"), (128, 1, 64, 128, "MQA Pro hd=128"), (8, 2, 64, 64, "GQA hd=64"), (8, 4, 64, 128, "GQA hd=128"), # P5: Multi-KV-tile (N > 128) — uses TMA multi-tile kernel (4, 4, 256, 64, "MHA hd=64 N=256 (2 tiles)"), (4, 4, 512, 64, "MHA hd=64 N=512 (4 tiles)"), (4, 1, 256, 64, "MQA hd=64 N=256 (2 tiles)"), (4, 1, 512, 64, "MQA hd=64 N=512 (4 tiles)"), (4, 1, 256, 128, "MQA hd=128 N=256 (2 tiles)"), (128, 1, 256, 64, "MQA Pro N=256 (2 tiles)"), ] all_pass = True for n_q, n_kv, N, hd, desc in configs: scale = 1.0 / math.sqrt(hd) try: q_4d = torch.randn(1, n_q, 1, hd, dtype=torch.bfloat16, device='cuda').contiguous() k_4d = torch.randn(1, n_kv, N, hd, dtype=torch.bfloat16, device='cuda').contiguous() v_4d = torch.randn(1, n_kv, hd, N, dtype=torch.bfloat16, device='cuda').contiguous() # Use the correct kernel for the KV size if N > 128 or hd == 512: from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw as kernel_fn o_4d, _ = kernel_fn(q_4d, k_4d, v_4d, scale) else: sb = torch.zeros(1, n_q, dtype=torch.float32, device='cuda') o_4d, _ = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale) o_ref = reference_attention(q_4d, k_4d, v_4d, scale) worst_cos = 1.0 for h in range(n_q): cos = torch.nn.functional.cosine_similarity( o_4d[0, h].float().flatten().unsqueeze(0), o_ref[h].float().flatten().unsqueeze(0), ).item() worst_cos = min(worst_cos, cos) status = "PASS" if worst_cos >= 0.999990 else "FAIL" if status == "FAIL": all_pass = False print(f" {status} {desc}: worst_cos={worst_cos:.6f}") except Exception as e: import traceback print(f" FAIL {desc}: {e}") traceback.print_exc() all_pass = False return all_pass def test_full_api(): """Test the full dsv4_attention API (fast path for T=1, N<=128).""" from dsv4.kernels.attention.production import dsv4_attention torch.manual_seed(99) configs = [ (8, 8, 128, 64, "MHA hd=64"), (8, 8, 128, 128, "MHA hd=128"), (8, 1, 128, 64, "MQA hd=64"), (8, 1, 128, 128, "MQA hd=128"), (8, 2, 128, 64, "GQA hd=64"), ] all_pass = True for n_q, n_kv, N, hd, desc in configs: scale = 1.0 / math.sqrt(hd) try: q = torch.randn(n_q, 1, hd, dtype=torch.bfloat16, device='cuda') if n_kv == 1: k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') else: k = torch.randn(n_kv, N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(n_kv, N, hd, dtype=torch.bfloat16, device='cuda') o_fast = dsv4_attention(q, k, v, scale=scale) # Reference using same data if n_kv == 1: k = k.unsqueeze(0) v = v.unsqueeze(0) q_per_kv = n_q // n_kv o_ref = torch.zeros(n_q, 1, hd, dtype=torch.bfloat16, device='cuda') for kv_idx in range(n_kv): k_h = k[kv_idx] v_h = v[kv_idx] for qi in range(q_per_kv): q_idx = kv_idx * q_per_kv + qi q_h = q[q_idx] s = torch.matmul(q_h.float(), k_h.float().T) * scale s = torch.softmax(s, dim=-1) o = torch.matmul(s, v_h.float()) o_ref[q_idx] = o.bfloat16() cos = cosine_sim(o_ref, o_fast).item() status = "PASS" if cos >= 0.999990 else "FAIL" if status == "FAIL": all_pass = False print(f" {status} [API] {desc}: cos={cos:.6f}") except Exception as e: import traceback print(f" FAIL [API] {desc}: {e}") traceback.print_exc() all_pass = False return all_pass if __name__ == "__main__": print("P3 Integration Test: 6-warp decode fast path") print("=" * 60) ok1 = test_kernel_correctness() print() ok2 = test_full_api() print("=" * 60) ok = ok1 and ok2 print("ALL PASS" if ok else "SOME FAILED") sys.exit(0 if ok else 1)