185 lines
6.4 KiB
Python
185 lines
6.4 KiB
Python
"""
|
|
P3 Integration Test: verify that the 6-warp multi-head decode fast path
produces results matching a PyTorch reference.

Covers MHA, MQA, and GQA at head dims HD = 64, 128, and 256.
A configuration passes when the cosine similarity between the kernel output
and the reference is >= 0.99999.
|
|
"""
|
|
import torch
|
|
import math
|
|
import sys
|
|
import os
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw
|
|
|
|
|
|
def cosine_sim(a, b):
    """Return the cosine similarity of ``a`` and ``b`` as a 0-d float32 tensor.

    Both inputs are flattened and upcast to float32 first, so any shapes with
    the same number of elements can be compared.  A tiny epsilon (1e-30) is
    added to the denominator so an all-zero input yields 0 rather than NaN.
    """
    lhs = a.flatten().float()
    rhs = b.flatten().float()
    denominator = lhs.norm() * rhs.norm() + 1e-30
    numerator = torch.dot(lhs, rhs)
    return numerator / denominator
|
|
|
|
|
|
def test_fast_path():
    """Test kernel vs PyTorch reference for MHA, MQA, GQA at various HD.

    For each (n_q, n_kv, N, hd) configuration, random bf16 Q/K/V tensors are
    created in the kernel's layout and passed to
    ``fmha_multihead_decode_raw``.  Its output is compared per query head
    against ``reference_attention_api`` evaluated on the same data.  A
    configuration passes when the worst per-head cosine similarity is
    >= 0.99999; a NaN in any head counts as a failure.

    Returns:
        True if every configuration passed, False otherwise.  An exception
        raised while testing one configuration is printed (with traceback)
        and counted as a failure; the remaining configurations still run.

    NOTE(review): pytest ignores test-function return values, so when this
    file is collected by pytest a failing configuration only shows up in the
    printed output.  Run the file as a script to get a meaningful exit code.
    """
    import traceback

    min_cos = 0.999990  # pass threshold for the worst per-head cosine

    torch.manual_seed(42)

    configs = [
        # (n_q, n_kv, N, hd, desc)
        (4, 4, 64, 64, "MHA hd=64"),
        (4, 4, 128, 64, "MHA hd=64 N=128"),
        (4, 4, 64, 128, "MHA hd=128"),
        (4, 4, 64, 256, "MHA hd=256"),
        (4, 1, 64, 64, "MQA hd=64"),
        (4, 1, 128, 64, "MQA hd=64 N=128"),
        (4, 1, 64, 128, "MQA hd=128"),
        (4, 1, 64, 256, "MQA hd=256"),
        (128, 1, 64, 64, "MQA Pro hd=64"),
        (128, 1, 64, 128, "MQA Pro hd=128"),
        (8, 2, 64, 64, "GQA hd=64"),
        (8, 4, 64, 128, "GQA hd=128"),
    ]

    all_pass = True
    for n_q, n_kv, N, hd, desc in configs:
        scale = 1.0 / math.sqrt(hd)

        try:
            # ---- Create data in KERNEL layout ----
            # Q: (1, n_q, 1, hd) - each head has 1 row of hd elements
            q_4d = torch.randn(1, n_q, 1, hd, dtype=torch.bfloat16, device='cuda').contiguous()
            # K: (1, n_kv, N, hd)
            k_4d = torch.randn(1, n_kv, N, hd, dtype=torch.bfloat16, device='cuda').contiguous()
            # V: (1, n_kv, hd, N) - the KERNEL expects V transposed (hd, N) per head
            v_4d = torch.randn(1, n_kv, hd, N, dtype=torch.bfloat16, device='cuda').contiguous()

            # ---- Kernel output ----
            # NOTE(review): the meaning of the positional arguments 0, 0, False
            # and of `sb` (a zero (1, n_q) float32 tensor - presumably a
            # per-head sink/bias term) is defined by the kernel wrapper;
            # confirm there.  The second return value (LSE) is not checked.
            sb = torch.zeros(1, n_q, dtype=torch.float32, device='cuda')
            o_kernel, _lse = fmha_multihead_decode_raw(
                q_4d, k_4d, v_4d, scale, 0, 0, False, sb
            )

            # ---- PyTorch reference using the SAME data ----
            # reference_attention_api takes Q (n_q, 1, hd) and K/V (n_kv, N, hd),
            # so V is transposed back from the kernel's (hd, N) layout to (N, hd).
            o_ref = reference_attention_api(
                q_4d[0],
                k_4d[0],
                v_4d[0].transpose(-1, -2),
                scale, n_q, n_kv, N, hd,
            )

            # ---- Compare per head ----
            # One cosine per query head.  torch's min() propagates NaN, so a
            # kernel emitting NaN fails; the builtin min(1.0, nan) would
            # return 1.0 and silently report a pass.
            per_head = torch.nn.functional.cosine_similarity(
                o_kernel[0].float().reshape(n_q, -1),
                o_ref.float().reshape(n_q, -1),
                dim=-1,
            )
            worst_cos = per_head.min().item()

            status = "PASS" if worst_cos >= min_cos else "FAIL"
            if status == "FAIL":
                all_pass = False
            print(f" {status} {desc}: worst_cos={worst_cos:.6f}")

        except Exception as e:
            # Report this configuration as failed and continue with the rest.
            print(f" FAIL {desc}: {e}")
            traceback.print_exc()
            all_pass = False

    return all_pass
|
|
|
|
|
|
def test_full_api():
    """Test the full dsv4_attention API against a PyTorch reference.

    Each configuration is a single-token (T=1) decode with N <= 128, which is
    expected to route through the fast path.  NOTE(review): nothing here
    verifies which path ``dsv4_attention`` actually took.

    Inputs are random bf16 tensors in ``dsv4_attention``'s own format:
    Q is (n_q, 1, hd); K/V are (N, hd) when n_kv == 1 and (n_kv, N, hd)
    otherwise.  The output is compared per query head against
    ``reference_attention_api``.  A configuration passes when the worst
    per-head cosine similarity is >= 0.99999; NaN counts as a failure.

    Returns:
        True if every configuration passed, False otherwise.  An exception
        raised while testing one configuration is printed and counted as a
        failure.  The import of ``dsv4_attention`` is deliberately left
        outside that handling, so an import error still propagates.
    """
    import traceback

    from dsv4.kernels.attention.production import dsv4_attention

    min_cos = 0.999990  # pass threshold for the worst per-head cosine

    torch.manual_seed(99)

    configs = [
        # (n_q, n_kv, N, hd, desc)
        (8, 8, 128, 64, "MHA hd=64"),
        (8, 8, 128, 128, "MHA hd=128"),
        (8, 1, 128, 64, "MQA hd=64"),
        (8, 1, 128, 128, "MQA hd=128"),
        (8, 2, 128, 64, "GQA hd=64"),
    ]

    all_pass = True
    for n_q, n_kv, N, hd, desc in configs:
        scale = 1.0 / math.sqrt(hd)
        try:
            q = torch.randn(n_q, 1, hd, dtype=torch.bfloat16, device='cuda')
            # MQA inputs are passed without a head axis: K/V are (N, hd).
            kv_shape = (N, hd) if n_kv == 1 else (n_kv, N, hd)
            k = torch.randn(*kv_shape, dtype=torch.bfloat16, device='cuda')
            v = torch.randn(*kv_shape, dtype=torch.bfloat16, device='cuda')

            # Full API call (expected to use the fast path)
            o_fast = dsv4_attention(q, k, v, scale=scale)

            # Reference
            o_ref = reference_attention_api(q, k, v, scale, n_q, n_kv, N, hd)

            # Per-head comparison.  Assumes o_fast's elements are laid out
            # head-major like o_ref's (n_q, 1, hd).  torch's min() propagates
            # NaN, so NaN output fails.
            per_head = torch.nn.functional.cosine_similarity(
                o_ref.float().reshape(n_q, -1),
                o_fast.float().reshape(n_q, -1),
                dim=-1,
            )
            worst_cos = per_head.min().item()

            status = "PASS" if worst_cos >= min_cos else "FAIL"
            if status == "FAIL":
                all_pass = False
            print(f" {status} [full API] {desc}: worst_cos={worst_cos:.6f}")
        except Exception as e:
            # Report this configuration as failed and continue with the rest.
            print(f" FAIL [full API] {desc}: {e}")
            traceback.print_exc()
            all_pass = False

    return all_pass
|
|
|
|
|
|
def reference_attention_api(q, k, v, scale, n_q, n_kv, N, hd):
    """Reference that matches dsv4_attention input format.

    Computes ``softmax(q_h @ k_h^T * scale) @ v_h`` for every query head in
    float32 and stores each head's result as bfloat16.

    Args:
        q: Queries, shape (n_q, 1, hd).
        k: Keys, shape (n_kv, N, hd), or (N, hd) for a single KV head.
        v: Values, same layout as ``k``.
        scale: Multiplier applied to the attention logits.
        n_q: Number of query heads.
        n_kv: Number of KV heads; must divide ``n_q``.  Query heads are
            grouped contiguously, so query head ``i`` attends to KV head
            ``i // (n_q // n_kv)``.
        N: KV sequence length (accepted for call-site symmetry; unused).
        hd: Head dimension.

    Returns:
        bfloat16 tensor of shape (n_q, 1, hd) on ``q``'s device.

    Raises:
        ValueError: if ``n_kv`` does not evenly divide ``n_q``, or if the
            leading dimension of ``q`` / ``k`` / ``v`` does not match the
            stated head counts.
    """
    # Without this check, a non-divisible n_q would silently leave some
    # query heads as all-zero rows in the reference.
    if n_kv <= 0 or n_q % n_kv != 0:
        raise ValueError(f"n_q ({n_q}) must be a positive multiple of n_kv ({n_kv})")
    q_per_kv = n_q // n_kv

    # A 2-D K/V means a single KV head; give it an explicit head axis.
    if k.dim() == 2:
        k = k.unsqueeze(0)
    if v.dim() == 2:
        v = v.unsqueeze(0)
    if q.shape[0] != n_q or k.shape[0] != n_kv or v.shape[0] != n_kv:
        raise ValueError(
            f"head-count mismatch: q has {q.shape[0]} (expected {n_q}), "
            f"k has {k.shape[0]} and v has {v.shape[0]} (expected {n_kv})"
        )

    output = torch.zeros(n_q, 1, hd, dtype=torch.bfloat16, device=q.device)
    for kv_idx in range(n_kv):
        # Upcast once per KV head rather than once per query head.
        k_h = k[kv_idx].float()  # (N, hd)
        v_h = v[kv_idx].float()  # (N, hd)
        for qi in range(q_per_kv):
            q_idx = kv_idx * q_per_kv + qi
            q_h = q[q_idx].float()  # (1, hd)
            s = torch.matmul(q_h, k_h.T) * scale  # (1, N)
            s = torch.softmax(s, dim=-1)
            o = torch.matmul(s, v_h)  # (1, hd)
            output[q_idx] = o.bfloat16()
    return output
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run both suites in order, print a one-line verdict,
    # and exit with status 0 only when every configuration passed.
    rule = "=" * 60
    print("P3 Integration Test: 6-warp decode fast path")
    print(rule)
    fast_path_ok = test_fast_path()
    print()  # blank line between the two suites
    full_api_ok = test_full_api()
    print(rule)
    everything_passed = all((fast_path_ok, full_api_ok))
    print("ALL PASS" if everything_passed else "SOME FAILED")
    sys.exit(0 if everything_passed else 1)
|