Files
nvfp4-megamoe-kernel/tests/unit/test_p3_fast_decode.py

190 lines
6.5 KiB
Python

"""
P3 Integration Test: 6-warp multi-head decode fast path.
Verifies that the kernel output matches a PyTorch reference for MHA, MQA,
and GQA at HD = 64, 128, 256, including multi-KV-tile cases (N > 128).
Gate: worst-case cosine similarity >= 0.999990 for every configuration.
"""
import torch
import math
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw
def cosine_sim(a, b):
    """Return the cosine similarity of two tensors as a 0-dim float32 tensor.

    Both inputs are flattened and cast to float32 before the comparison.
    A tiny constant is added to the denominator so that all-zero inputs
    yield 0 instead of dividing by zero.
    """
    lhs = a.flatten().float()
    rhs = b.flatten().float()
    denom = lhs.norm() * rhs.norm() + 1e-30
    return torch.dot(lhs, rhs) / denom
def reference_attention(q_4d, k_4d, v_4d, scale):
    """Compute single-query attention per head in float32 as a PyTorch reference.

    Only batch index 0 of each input is read.

    Args:
        q_4d: Query tensor indexed as (1, n_h, 1, hd).
        k_4d: Key tensor indexed as (1, n_kv, N, hd).
        v_4d: Value tensor in kernel layout (1, n_kv, hd, N); it is transposed
            to (n_kv, N, hd) before use.
        scale: Multiplier applied to the Q @ K^T scores before the softmax.

    Returns:
        A bfloat16 tensor of shape (n_h, 1, hd) allocated on q_4d's device.
        For MQA/GQA, query head h reads KV head h // (n_h // n_kv).

    Raises:
        ValueError: If n_h is not a multiple of n_kv (or n_kv is 0).
    """
    n_h = q_4d.shape[1]
    n_kv = k_4d.shape[1]
    if n_kv == 0 or n_h % n_kv != 0:
        raise ValueError(
            f"number of query heads ({n_h}) must be a multiple of "
            f"the number of KV heads ({n_kv})"
        )
    q_per_kv = n_h // n_kv

    # Cast once up front; the math is done in float32 and only the final
    # per-head result is rounded back to bfloat16.
    q = q_4d[0].float()                     # (n_h, 1, hd)
    k = k_4d[0].float()                     # (n_kv, N, hd)
    v = v_4d[0].transpose(-1, -2).float()   # (n_kv, N, hd)

    # Allocate next to the inputs instead of assuming a CUDA device.
    output = torch.zeros(
        n_h, 1, q_4d.shape[3], dtype=torch.bfloat16, device=q_4d.device
    )
    for h in range(n_h):
        kv_idx = h // q_per_kv  # consecutive query heads share one KV head
        probs = torch.softmax(torch.matmul(q[h], k[kv_idx].T) * scale, dim=-1)
        output[h] = torch.matmul(probs, v[kv_idx]).bfloat16()
    return output
def test_kernel_correctness():
    """Compare fmha_multitile_decode_raw with reference_attention, head by head.

    For each (n_q, n_kv, N, hd, desc) configuration, random bfloat16 Q/K/V
    tensors are created on CUDA (seed 42), the kernel and the reference are
    run with scale = 1/sqrt(hd), and the lowest per-head cosine similarity is
    compared with the 0.999990 gate. A NaN cosine counts as a failure. Any
    exception raised for a configuration is printed with its traceback and
    recorded as a failure so that the remaining configurations still run.

    Returns:
        True if every configuration passed, False otherwise.

    NOTE(review): pytest does not act on a test's return value, so under
    pytest a False result does not fail the test; only the __main__ runner
    uses it. Confirm whether this should assert instead.
    """
    torch.manual_seed(42)
    gate = 0.999990
    configs = [
        # (n_q, n_kv, N, hd, desc)
        (4, 4, 64, 64, "MHA hd=64"),
        (4, 4, 128, 64, "MHA hd=64 N=128"),
        (4, 4, 64, 128, "MHA hd=128"),
        (4, 4, 64, 256, "MHA hd=256"),
        (4, 1, 64, 64, "MQA hd=64"),
        (4, 1, 128, 64, "MQA hd=64 N=128"),
        (4, 1, 64, 128, "MQA hd=128"),
        (4, 1, 64, 256, "MQA hd=256"),
        (128, 1, 64, 64, "MQA Pro hd=64"),
        (128, 1, 64, 128, "MQA Pro hd=128"),
        (8, 2, 64, 64, "GQA hd=64"),
        (8, 4, 64, 128, "GQA hd=128"),
        # P5: multi-KV-tile cases (N > 128)
        (4, 4, 256, 64, "MHA hd=64 N=256 (2 tiles)"),
        (4, 4, 512, 64, "MHA hd=64 N=512 (4 tiles)"),
        (4, 1, 256, 64, "MQA hd=64 N=256 (2 tiles)"),
        (4, 1, 512, 64, "MQA hd=64 N=512 (4 tiles)"),
        (4, 1, 256, 128, "MQA hd=128 N=256 (2 tiles)"),
        (128, 1, 256, 64, "MQA Pro N=256 (2 tiles)"),
    ]
    all_pass = True
    for n_q, n_kv, N, hd, desc in configs:
        scale = 1.0 / math.sqrt(hd)
        try:
            q_4d = torch.randn(1, n_q, 1, hd, dtype=torch.bfloat16, device='cuda')
            k_4d = torch.randn(1, n_kv, N, hd, dtype=torch.bfloat16, device='cuda')
            v_4d = torch.randn(1, n_kv, hd, N, dtype=torch.bfloat16, device='cuda')

            # Every configuration goes through the same entry point; the
            # previous size-based branch called this function on both sides.
            o_4d, _ = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale)
            o_ref = reference_attention(q_4d, k_4d, v_4d, scale)

            worst_cos = 1.0
            for h in range(n_q):
                cos = torch.nn.functional.cosine_similarity(
                    o_4d[0, h].float().flatten().unsqueeze(0),
                    o_ref[h].float().flatten().unsqueeze(0),
                ).item()
                if math.isnan(cos):
                    # min() would silently drop a NaN and let the config
                    # pass; keep it so the gate comparison below fails.
                    worst_cos = cos
                    break
                worst_cos = min(worst_cos, cos)

            status = "PASS" if worst_cos >= gate else "FAIL"
            if status == "FAIL":
                all_pass = False
            print(f" {status} {desc}: worst_cos={worst_cos:.6f}")
        except Exception as e:
            # Deliberately broad: one broken configuration must not stop
            # the others from being reported.
            import traceback
            print(f" FAIL {desc}: {e}")
            traceback.print_exc()
            all_pass = False
    return all_pass
def test_full_api():
    """Check dsv4_attention against the PyTorch reference (T=1, N=128).

    For each (n_q, n_kv, N, hd, desc) configuration, random bfloat16 inputs
    are created on CUDA (seed 99). Q is (n_q, 1, hd); K and V are (N, hd)
    when n_kv == 1 and (n_kv, N, hd) otherwise. The API output is compared
    with reference_attention one head at a time, and the lowest per-head
    cosine similarity must reach the 0.999990 gate. A NaN cosine counts as a
    failure. Exceptions are printed with their traceback and recorded as a
    failure so the remaining configurations still run.

    Returns:
        True if every configuration passed, False otherwise.
    """
    from dsv4.kernels.attention.production import dsv4_attention
    torch.manual_seed(99)
    gate = 0.999990
    configs = [
        # (n_q, n_kv, N, hd, desc)
        (8, 8, 128, 64, "MHA hd=64"),
        (8, 8, 128, 128, "MHA hd=128"),
        (8, 1, 128, 64, "MQA hd=64"),
        (8, 1, 128, 128, "MQA hd=128"),
        (8, 2, 128, 64, "GQA hd=64"),
    ]
    all_pass = True
    for n_q, n_kv, N, hd, desc in configs:
        scale = 1.0 / math.sqrt(hd)
        try:
            q = torch.randn(n_q, 1, hd, dtype=torch.bfloat16, device='cuda')
            if n_kv == 1:
                k = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda')
                v = torch.randn(N, hd, dtype=torch.bfloat16, device='cuda')
            else:
                k = torch.randn(n_kv, N, hd, dtype=torch.bfloat16, device='cuda')
                v = torch.randn(n_kv, N, hd, dtype=torch.bfloat16, device='cuda')
            o_fast = dsv4_attention(q, k, v, scale=scale)

            # Reference on the same data: add the KV-head axis for the 2-D
            # MQA case, then convert to reference_attention's layout
            # (leading batch axis, V stored as (hd, N)).
            k_ref = k.unsqueeze(0) if n_kv == 1 else k
            v_ref = v.unsqueeze(0) if n_kv == 1 else v
            o_ref = reference_attention(
                q.unsqueeze(0),
                k_ref.unsqueeze(0),
                v_ref.transpose(-1, -2).unsqueeze(0),
                scale,
            )

            # Gate on the worst head rather than on one cosine over all
            # heads, which can mask a single inaccurate head. Rows follow
            # the same flattened element order the comparison used before.
            fast_rows = o_fast.reshape(n_q, -1)
            cos = 1.0
            for h in range(n_q):
                head_cos = cosine_sim(o_ref[h], fast_rows[h]).item()
                if math.isnan(head_cos):
                    # min() would drop a NaN; keep it so the gate fails.
                    cos = head_cos
                    break
                cos = min(cos, head_cos)

            status = "PASS" if cos >= gate else "FAIL"
            if status == "FAIL":
                all_pass = False
            print(f" {status} [API] {desc}: cos={cos:.6f}")
        except Exception as e:
            # Deliberately broad: one broken configuration must not stop
            # the others from being reported.
            import traceback
            print(f" FAIL [API] {desc}: {e}")
            traceback.print_exc()
            all_pass = False
    return all_pass
if __name__ == "__main__":
    # Script entry point: run both suites, print a summary, and report the
    # overall result through the process exit code (0 = all passed).
    separator = "=" * 60
    print("P3 Integration Test: 6-warp decode fast path")
    print(separator)
    kernel_ok = test_kernel_correctness()
    print()
    api_ok = test_full_api()
    print(separator)
    if kernel_ok and api_ok:
        print("ALL PASS")
        sys.exit(0)
    print("SOME FAILED")
    sys.exit(1)