Test multi-KV merge (2 segments) separately from multi-head

This commit is contained in:
2026-05-27 06:54:16 +00:00
parent 36a6f07a7e
commit 3a25c7feff

View File

@@ -72,10 +72,44 @@ def test_production_multi_head():
print(f" hd={hd}, n_h={n_h}, N={N}: cos {cos:.6f} {status}")
def test_production_multi_kv():
    """Test multi-KV-tile with Python KV merge.

    Runs ``dsv4_attention`` on one head (``n_h=1``) with ``T=128`` queries,
    ``N=256`` keys/values and head dim 64, using bf16 inputs on CUDA, and
    compares the result against a float32 PyTorch reference built two ways:

    * ``ref_norm``   -- ``softmax(Q K^T * scale) @ V``
    * ``ref_unnorm`` -- ``exp(Q K^T * scale - rowmax) @ V`` (no division by
      the row sum)

    The cosine similarity of the flattened output against each reference is
    printed. Nothing is asserted, so this is a diagnostic check: it shows
    which of the two references the kernel output matches, and cannot fail
    on a numerical mismatch.
    """
    torch.manual_seed(42)
    hd = 64
    n_h = 1
    T = 128
    N = 256  # 2 KV segments

    # Draw order (q, k, v) is kept so the seeded inputs stay reproducible.
    q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device='cuda')

    # PyTorch reference, computed in float32 for head 0 (the only head here).
    scale = 1.0 / math.sqrt(hd)
    qf = q[0].float()
    kf = k[0].float()
    vf = v[0].float()
    # Scaled scores are computed once and reused for the row max and the exp.
    scores = qf @ kf.T * scale
    attn_max = scores.max(dim=-1, keepdim=True)[0]
    # Subtracting the row max keeps the exponentials in a safe range.
    attn_exp = torch.exp(scores - attn_max)
    attn_sum = attn_exp.sum(dim=-1, keepdim=True)
    ref_norm = (attn_exp / attn_sum) @ vf
    ref_unnorm = attn_exp @ vf

    out = dsv4_attention(q, k, v)

    # Compare as single flattened row vectors of shape (1, numel).
    out_row = out.flatten().unsqueeze(0)
    cos_unnorm = torch.nn.functional.cosine_similarity(
        out_row, ref_unnorm.flatten().unsqueeze(0)
    ).item()
    cos_norm = torch.nn.functional.cosine_similarity(
        out_row, ref_norm.flatten().unsqueeze(0)
    ).item()
    print(f" hd={hd}, n_h={n_h}, N={N}: cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f}")
def test():
    """Print the suite banner, then run each production wrapper test in order."""
    print("=== Production DSV4 Attention Wrapper ===\n")
    # Order matters only for the printed output; a failure stops the run.
    for run_case in (
        test_production_basic,
        test_production_multi_head,
        test_production_multi_kv,
    ):
        run_case()
if __name__ == '__main__':