# Source listing metadata: 117 lines, 3.6 KiB, Python
"""Test production DSV4 attention wrapper."""
|
|
import torch
|
|
import math
|
|
from dsv4.kernels.attention.production import dsv4_attention
|
|
|
|
|
|
def test_production_basic() -> None:
    """Report how closely single-head dsv4_attention matches a PyTorch reference.

    Random bfloat16 CUDA inputs are created (q: (1, 128, 64); k, v: (128, 64)).
    Two float32 references are built from the same max-shifted exponentials:
    an un-normalized one (``exp(s - max) @ v``) and a softmax-normalized one.
    The cosine similarity between the flattened kernel output and each
    reference is printed.

    NOTE(review): nothing is asserted here; the similarities are only printed.
    """
    torch.manual_seed(42)
    hd, n_h, T, N = 64, 1, 128, 128

    def _rand_bf16(*shape):
        # Same dtype/device for every input; call order fixes the RNG stream.
        return torch.randn(*shape, dtype=torch.bfloat16, device='cuda')

    q = _rand_bf16(n_h, T, hd)
    k = _rand_bf16(N, hd)
    v = _rand_bf16(N, hd)

    # Float32 reference for the single head.
    query, keys, values = q[0].float(), k.float(), v.float()
    logits = (query @ keys.T) * (1.0 / math.sqrt(hd))
    # Subtract the row-wise maximum before exponentiating.
    row_max = logits.max(dim=-1, keepdim=True).values
    weights = torch.exp(logits - row_max)
    row_sum = weights.sum(dim=-1, keepdim=True)
    ref_unnorm = weights @ values
    ref_norm = (weights / row_sum) @ values

    out = dsv4_attention(q, k, v)

    def _flat_cosine(lhs, rhs):
        # Treat both tensors as one long vector and compare their directions.
        return torch.nn.functional.cosine_similarity(
            lhs.reshape(1, -1), rhs.reshape(1, -1)
        ).item()

    cos_unnorm = _flat_cosine(out, ref_unnorm)
    cos_norm = _flat_cosine(out, ref_norm)
    print(f" hd={hd}, n_h={n_h}, N={N}: cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f}")
|
|
|
|
|
|
def test_production_multi_head() -> None:
    """Check multi-head dsv4_attention against a PyTorch softmax reference.

    Random bfloat16 CUDA inputs are created (q: (4, 128, 64); k, v:
    (4, 256, 64)).  The reference is scaled-dot-product softmax attention
    evaluated per head in float32 and rounded to bfloat16.  The cosine
    similarity between the flattened kernel output and the flattened
    reference is printed together with a PASS/FAIL status.

    Raises:
        AssertionError: if the cosine similarity is below 0.99 (or is NaN).
            Previously the status was only printed, so the test could never
            fail under a test runner.
    """
    torch.manual_seed(42)
    hd = 64
    n_h = 4
    T = 128
    N = 256
    threshold = 0.99

    q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(n_h, N, hd, dtype=torch.bfloat16, device='cuda')

    # PyTorch reference: batched over the leading head dimension, in float32.
    scale = 1.0 / math.sqrt(hd)
    scores = torch.matmul(q.float(), k.float().transpose(-1, -2)) * scale
    probs = torch.softmax(scores, dim=-1)
    # Round to bfloat16 so the reference carries the same precision as q.
    ref = torch.matmul(probs, v.float()).bfloat16()

    out = dsv4_attention(q, k, v)

    cos = torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0), ref.float().flatten().unsqueeze(0)
    ).item()
    status = "PASS" if cos >= threshold else "FAIL"
    # Print first so the diagnostic line is visible even when the check fails.
    print(f" hd={hd}, n_h={n_h}, N={N}: cos {cos:.6f} {status}")
    assert cos >= threshold, (
        f"multi-head cosine similarity {cos:.6f} is below {threshold}"
    )
|
|
|
|
|
|
def test_production_multi_kv() -> None:
    """Report dsv4_attention accuracy when the KV length exceeds the query length.

    Uses one head with T = 128 queries and N = 256 keys/values (described in
    the original as "2 KV segments").  Inputs are random bfloat16 CUDA
    tensors (q: (1, 128, 64); k, v: (1, 256, 64)).  The flattened kernel output
    is compared, by cosine similarity, with an un-normalized and a
    softmax-normalized float32 reference, and both values are printed.

    NOTE(review): nothing is asserted here; the similarities are only printed.
    """
    torch.manual_seed(42)
    hd, n_h, T = 64, 1, 128
    N = 256  # 2 KV segments

    tensor_opts = {"dtype": torch.bfloat16, "device": 'cuda'}
    # Creation order (q, k, v) is kept so the seeded RNG yields the same data.
    q = torch.randn(n_h, T, hd, **tensor_opts)
    k = torch.randn(n_h, N, hd, **tensor_opts)
    v = torch.randn(n_h, N, hd, **tensor_opts)

    # Float32 reference for head 0.
    inv_sqrt_hd = 1.0 / math.sqrt(hd)
    head_q = q[0].float()
    head_k = k[0].float()
    head_v = v[0].float()
    logits = (head_q @ head_k.T) * inv_sqrt_hd
    # Subtract the row-wise maximum before exponentiating.
    shifted = logits - logits.max(dim=-1, keepdim=True).values
    weights = torch.exp(shifted)
    totals = weights.sum(dim=-1, keepdim=True)
    ref_norm = (weights / totals) @ head_v
    ref_unnorm = weights @ head_v

    out = dsv4_attention(q, k, v)

    # Compare the output, viewed as a single row vector, with each reference.
    out_row = out.reshape(1, -1)
    cos_unnorm, cos_norm = (
        torch.nn.functional.cosine_similarity(out_row, ref.reshape(1, -1)).item()
        for ref in (ref_unnorm, ref_norm)
    )
    print(f" hd={hd}, n_h={n_h}, N={N}: cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f}")
|
|
|
|
|
|
def test() -> None:
    """Print the banner, then run the basic and multi-KV checks in that order.

    NOTE(review): test_production_multi_head is not invoked from here --
    confirm whether that omission is intentional.
    """
    print("=== Production DSV4 Attention Wrapper ===\n")
    for check in (test_production_basic, test_production_multi_kv):
        check()
|
|
|
|
|
|
# Run the checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test()
|