diff --git a/tests/unit/test_production.py b/tests/unit/test_production.py index 4f3690b9..dc6d62b0 100644 --- a/tests/unit/test_production.py +++ b/tests/unit/test_production.py @@ -72,10 +72,44 @@ def test_production_multi_head(): print(f" hd={hd}, n_h={n_h}, N={N}: cos {cos:.6f} {status}") +def test_production_multi_kv(): + """Test multi-KV-tile with Python KV merge.""" + torch.manual_seed(42) + hd = 64 + n_h = 1 + T = 128 + N = 256 # 2 KV segments + + q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device='cuda') + k = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device='cuda') + v = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device='cuda') + + # PyTorch reference + scale = 1.0 / math.sqrt(hd) + qf = q[0].float() + kf = k[0].float() + vf = v[0].float() + attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(qf @ kf.T * scale - attn_max) + attn_sum = attn_exp.sum(dim=-1, keepdim=True) + ref_norm = (attn_exp / attn_sum) @ vf + ref_unnorm = attn_exp @ vf + + out = dsv4_attention(q, k, v) + + cos_unnorm = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), ref_unnorm.unsqueeze(0).flatten().unsqueeze(0) + ).item() + cos_norm = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), ref_norm.unsqueeze(0).flatten().unsqueeze(0) + ).item() + print(f" hd={hd}, n_h={n_h}, N={N}: cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f}") + + def test(): print("=== Production DSV4 Attention Wrapper ===\n") test_production_basic() - test_production_multi_head() + test_production_multi_kv() if __name__ == '__main__':