"""Comprehensive test suite for Stage E production attention. Tests: 1. MHA / MQA / GQA correctness (head-packed) 2. Batch dimension support 3. Multi-segment KV (Python KV merge) 4. SWA masking + causal + sink bias 5. Per-head launch vs head-packed parity 6. Reference parity against FP32 oracle 7. Custom op registration 8. Edge cases: single token, single head, exact-fit segments """ import torch import math import pytest from dsv4.kernels.attention.production import dsv4_attention, dsv4_attention_per_head # noqa: E402 # --------------------------------------------------------------------------- # Reference implementations # --------------------------------------------------------------------------- def _pytorch_ref_attention( q: torch.Tensor, # (n_q, T, hd) k: torch.Tensor, # (n_kv, N, hd) or (N, hd) v: torch.Tensor, # same as k scale: float, swa_len: int = None, is_causal: bool = False, n_comp: int = 0, sink_bias: torch.Tensor = None, # (n_q,) or scalar ) -> torch.Tensor: """Full-precision PyTorch reference with SWA mask, causal, and sink bias.""" n_q, T, hd = q.shape if k.dim() == 2: k = k.unsqueeze(0) v = v.unsqueeze(0) n_kv, N, _ = k.shape q_per_kv = n_q // n_kv ref = torch.zeros(n_q, T, hd, dtype=torch.float32, device='cuda') for qi in range(n_q): ki = qi // q_per_kv qf = q[qi].float() # (T, hd) kf = k[ki].float() # (N, hd) vf = v[ki].float() # (N, hd) # QK^T attn = qf @ kf.T * scale # (T, N) # Sink bias: add to SWA positions (>= n_comp) if sink_bias is not None: sb = float(sink_bias[qi]) if sink_bias.numel() > 1 else float(sink_bias[0]) for pos in range(N): if pos >= n_comp: attn[:, pos] += sb # SWA mask: mask positions >= n_comp + swa_len if swa_len is not None: for pos in range(N): if pos >= n_comp + swa_len: attn[:, pos] = float('-inf') # Causal mask on SWA region if is_causal: for t in range(T): for pos in range(N): if pos >= n_comp: swa_pos = pos - n_comp if swa_pos > t: attn[t, pos] = float('-inf') ref[qi] = torch.softmax(attn, dim=-1) @ vf return ref.bfloat16() # --------------------------------------------------------------------------- # Basic MHA / MQA / GQA # --------------------------------------------------------------------------- def test_mha_basic(): """MHA: n_q = n_kv.""" torch.manual_seed(42) hd = 64; T = 128; N = 256; n_q = 4; n_kv = 4 q = torch.randn(n_q, T, hd, dtype=torch.bfloat16, device='cuda') k = torch.randn(n_kv, N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(n_kv, N, hd, dtype=torch.bfloat16, device='cuda') out = dsv4_attention(q, k, v) ref = _pytorch_ref_attention(q, k, v, 1.0 / math.sqrt(hd)) cos = torch.nn.functional.cosine_similarity( out.flatten().unsqueeze(0), ref.float().flatten().unsqueeze(0) ).item() print(f" MHA n_q={n_q} n_kv={n_kv} N={N}: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}") assert cos >= 0.99, f"MHA cos={cos}" def test_mqa_basic(): """MQA: n_q > 1, n_kv = 1 (shared K/V).""" torch.manual_seed(42) hd = 64; T = 128; N = 256; n_q = 8; n_kv = 1 q = torch.randn(n_q, T, hd, dtype=torch.bfloat16, device='cuda') k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') # 2D v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') out = dsv4_attention(q, k, v) ref = _pytorch_ref_attention(q, k.unsqueeze(0), v.unsqueeze(0), 1.0 / math.sqrt(hd)) cos = torch.nn.functional.cosine_similarity( out.flatten().unsqueeze(0), ref.float().flatten().unsqueeze(0) ).item() print(f" MQA n_q={n_q} n_kv=1 N={N}: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}") assert cos >= 0.99, f"MQA cos={cos}" def test_gqa_basic(): """GQA: n_q > n_kv > 1.""" torch.manual_seed(42) hd = 64; T = 128; N = 256; n_q = 8; n_kv = 2 q = torch.randn(n_q, T, hd, dtype=torch.bfloat16, device='cuda') k = torch.randn(n_kv, N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(n_kv, N, hd, dtype=torch.bfloat16, device='cuda') out = dsv4_attention(q, k, v) ref = _pytorch_ref_attention(q, k, v, 1.0 / math.sqrt(hd)) cos = torch.nn.functional.cosine_similarity( out.flatten().unsqueeze(0), ref.float().flatten().unsqueeze(0) ).item() print(f" GQA n_q={n_q} n_kv={n_kv} N={N}: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}") assert cos >= 0.99, f"GQA cos={cos}" # --------------------------------------------------------------------------- # Head-packed vs per-head parity # --------------------------------------------------------------------------- def test_head_packed_vs_per_head(): """Head-packed and per-head launches should produce identical results (no sink bias).""" torch.manual_seed(42) hd = 64; T = 128; N = 256; n_q = 4; n_kv = 1 q = torch.randn(n_q, T, hd, dtype=torch.bfloat16, device='cuda') k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') out_packed = dsv4_attention(q, k, v) out_per_head = dsv4_attention_per_head(q, k, v) cos = torch.nn.functional.cosine_similarity( out_packed.flatten().unsqueeze(0), out_per_head.float().flatten().unsqueeze(0) ).item() max_diff = (out_packed.float() - out_per_head.float()).abs().max().item() print(f" Packed vs per-head: cos {cos:.6f} max_diff {max_diff:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}") assert cos >= 0.999, f"Packed vs per-head cos={cos}" # --------------------------------------------------------------------------- # Multi-segment KV (Python KV merge) # --------------------------------------------------------------------------- def test_multi_segment_kv(): """N > 128 triggers Python KV merge across segments.""" torch.manual_seed(42) hd = 64; T = 128; N = 512; n_q = 2 q = torch.randn(n_q, T, hd, dtype=torch.bfloat16, device='cuda') k = torch.randn(n_q, N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(n_q, N, hd, dtype=torch.bfloat16, device='cuda') out = dsv4_attention(q, k, v) ref = _pytorch_ref_attention(q, k, v, 1.0 / math.sqrt(hd)) cos = torch.nn.functional.cosine_similarity( out.flatten().unsqueeze(0), ref.float().flatten().unsqueeze(0) ).item() print(f" Multi-seg N={N}: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}") assert cos >= 0.99, f"Multi-seg cos={cos}" # --------------------------------------------------------------------------- # SWA + causal + sink bias # --------------------------------------------------------------------------- def test_swa_causal_sink(): """SWA masking + causal + sink bias (D3+D4+D5c combined).""" torch.manual_seed(42) hd = 64; T = 64; N = 256; n_q = 1 q = torch.randn(n_q, T, hd, dtype=torch.bfloat16, device='cuda') k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') # MQA v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') swa_len = 128 n_comp = 64 # first 64 positions are compressed sink_bias = torch.tensor([0.5], dtype=torch.float32, device='cuda') out = dsv4_attention( q, k, v, swa_len=swa_len, is_causal=True, n_comp=n_comp, sink_bias=sink_bias, ) ref = _pytorch_ref_attention( q, k.unsqueeze(0), v.unsqueeze(0), 1.0 / math.sqrt(hd), swa_len=swa_len, is_causal=True, n_comp=n_comp, sink_bias=sink_bias, ) cos = torch.nn.functional.cosine_similarity( out.flatten().unsqueeze(0), ref.float().flatten().unsqueeze(0) ).item() print(f" SWA+causal+sink: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}") assert cos >= 0.99, f"SWA+causal+sink cos={cos}" # --------------------------------------------------------------------------- # Batch dimension # --------------------------------------------------------------------------- def test_batch_dimension(): """Batch dim: (batch, n_q, T, hd) input/output.""" torch.manual_seed(42) hd = 64; T = 128; N = 128; n_q = 2; batch = 2 q = torch.randn(batch, n_q, T, hd, dtype=torch.bfloat16, device='cuda') k = torch.randn(batch, n_q, N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(batch, n_q, N, hd, dtype=torch.bfloat16, device='cuda') out = dsv4_attention(q, k, v) assert out.shape == q.shape, f"Shape mismatch: {out.shape} vs {q.shape}" # Verify each batch item individually for b in range(batch): ref_b = _pytorch_ref_attention(q[b], k[b], v[b], 1.0 / math.sqrt(hd)) cos = torch.nn.functional.cosine_similarity( out[b].flatten().unsqueeze(0), ref_b.float().flatten().unsqueeze(0) ).item() print(f" Batch[{b}]: cos {cos:.6f}") assert cos >= 0.99, f"Batch[{b}] cos={cos}" # --------------------------------------------------------------------------- # Edge cases # --------------------------------------------------------------------------- def test_single_token(): """T=1 decode case.""" torch.manual_seed(42) hd = 64; T = 1; N = 128; n_q = 1 q = torch.randn(n_q, T, hd, dtype=torch.bfloat16, device='cuda') k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') out = dsv4_attention(q, k, v) ref = _pytorch_ref_attention(q, k.unsqueeze(0), v.unsqueeze(0), 1.0 / math.sqrt(hd)) cos = torch.nn.functional.cosine_similarity( out.flatten().unsqueeze(0), ref.float().flatten().unsqueeze(0) ).item() print(f" Single token T=1: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}") assert cos >= 0.99, f"Single token cos={cos}" def test_exact_fit_segment(): """N exactly equals s_k=128 (single segment, no padding).""" torch.manual_seed(42) hd = 64; T = 128; N = 128; n_q = 1 q = torch.randn(n_q, T, hd, dtype=torch.bfloat16, device='cuda') k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') out = dsv4_attention(q, k, v) ref = _pytorch_ref_attention(q, k.unsqueeze(0), v.unsqueeze(0), 1.0 / math.sqrt(hd)) cos = torch.nn.functional.cosine_similarity( out.flatten().unsqueeze(0), ref.float().flatten().unsqueeze(0) ).item() print(f" Exact fit N=128: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}") assert cos >= 0.99, f"Exact fit cos={cos}" def test_partial_segment(): """N=200 → 2 segments, second segment partially padded.""" torch.manual_seed(42) hd = 64; T = 128; N = 200; n_q = 1 q = torch.randn(n_q, T, hd, dtype=torch.bfloat16, device='cuda') k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') out = dsv4_attention(q, k, v) ref = _pytorch_ref_attention(q, k.unsqueeze(0), v.unsqueeze(0), 1.0 / math.sqrt(hd)) cos = torch.nn.functional.cosine_similarity( out.flatten().unsqueeze(0), ref.float().flatten().unsqueeze(0) ).item() print(f" Partial seg N=200: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}") assert cos >= 0.99, f"Partial segment cos={cos}" # --------------------------------------------------------------------------- # Custom op # --------------------------------------------------------------------------- def test_custom_op(): """torch.library.custom_op registration and execution.""" from dsv4.ops.custom_ops import dsv4_sparse_fmha torch.manual_seed(42) hd = 64; T = 128; N = 128; n_q = 1 q = torch.randn(n_q, T, hd, dtype=torch.bfloat16, device='cuda') k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda') sink_bias = torch.zeros(n_q, dtype=torch.float32, device='cuda') out = dsv4_sparse_fmha(q, k, v, sink_bias, 1.0 / math.sqrt(hd), 0, False, 0) ref = _pytorch_ref_attention(q, k.unsqueeze(0), v.unsqueeze(0), 1.0 / math.sqrt(hd)) cos = torch.nn.functional.cosine_similarity( out.flatten().unsqueeze(0), ref.float().flatten().unsqueeze(0) ).item() print(f" Custom op: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}") assert cos >= 0.99, f"Custom op cos={cos}" # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def test(): print("=" * 60) print("Stage E: Production DSV4 Attention — Comprehensive Tests") print("=" * 60) print("\n--- Basic MHA / MQA / GQA ---") test_mha_basic() test_mqa_basic() test_gqa_basic() print("\n--- Head-packed vs per-head parity ---") test_head_packed_vs_per_head() print("\n--- Multi-segment KV (Python KV merge) ---") test_multi_segment_kv() print("\n--- SWA + causal + sink bias ---") test_swa_causal_sink() print("\n--- Batch dimension ---") test_batch_dimension() print("\n--- Edge cases ---") test_single_token() test_exact_fit_segment() test_partial_segment() print("\n--- Custom op ---") test_custom_op() print("\n" + "=" * 60) print("ALL TESTS PASSED") print("=" * 60) if __name__ == '__main__': test()