Files
nvfp4-megamoe-kernel/tests/unit/test_production.py
biondizzle b9f15c250f Stage E: head-packed MQA/GQA, batch dim, custom_op, integration API
- production.py: head-packed M dimension for MQA/GQA (q_per_kv*T rows
  in single launch per KV group, eliminating redundant K/V TMA loads)
- production.py: batch dimension support (outer Python loop)
- production.py: warmup_attention_kernels() for pre-compilation
- production.py: dsv4_attention_per_head() for exact per-head sink bias
- __init__.py: sparse_fmha_with_swa, dense_fmha_with_swa, swa_only_fmha
  integration functions bridging AttentionSubBlock → production FMHA
- custom_ops.py: dsv4::sparse_fmha_with_swa custom_op registration
- test_production.py: comprehensive tests (MHA/MQA/GQA, head-packed vs
  per-head parity, multi-segment KV, SWA+causal+sink, batch, edge cases)
2026-05-27 15:15:03 +00:00

363 lines
13 KiB
Python

"""Comprehensive test suite for Stage E production attention.
Tests:
1. MHA / MQA / GQA correctness (head-packed)
2. Batch dimension support
3. Multi-segment KV (Python KV merge)
4. SWA masking + causal + sink bias
5. Per-head launch vs head-packed parity
6. Reference parity against FP32 oracle
7. Custom op registration
8. Edge cases: single token, single head, exact-fit segments
"""
import torch
import math
import pytest
from dsv4.kernels.attention.production import dsv4_attention, dsv4_attention_per_head
# ---------------------------------------------------------------------------
# Reference implementations
# ---------------------------------------------------------------------------
def _pytorch_ref_attention(
q: torch.Tensor, # (n_q, T, hd)
k: torch.Tensor, # (n_kv, N, hd) or (N, hd)
v: torch.Tensor, # same as k
scale: float,
swa_len: int = None,
is_causal: bool = False,
n_comp: int = 0,
sink_bias: torch.Tensor = None, # (n_q,) or scalar
) -> torch.Tensor:
"""Full-precision PyTorch reference with SWA mask, causal, and sink bias."""
n_q, T, hd = q.shape
if k.dim() == 2:
k = k.unsqueeze(0)
v = v.unsqueeze(0)
n_kv, N, _ = k.shape
q_per_kv = n_q // n_kv
ref = torch.zeros(n_q, T, hd, dtype=torch.float32, device='cuda')
for qi in range(n_q):
ki = qi // q_per_kv
qf = q[qi].float() # (T, hd)
kf = k[ki].float() # (N, hd)
vf = v[ki].float() # (N, hd)
# QK^T
attn = qf @ kf.T * scale # (T, N)
# Sink bias: add to SWA positions (>= n_comp)
if sink_bias is not None:
sb = float(sink_bias[qi]) if sink_bias.numel() > 1 else float(sink_bias[0])
for pos in range(N):
if pos >= n_comp:
attn[:, pos] += sb
# SWA mask: mask positions >= n_comp + swa_len
if swa_len is not None:
for pos in range(N):
if pos >= n_comp + swa_len:
attn[:, pos] = float('-inf')
# Causal mask on SWA region
if is_causal:
for t in range(T):
for pos in range(N):
if pos >= n_comp:
swa_pos = pos - n_comp
if swa_pos > t:
attn[t, pos] = float('-inf')
ref[qi] = torch.softmax(attn, dim=-1) @ vf
return ref.bfloat16()
# ---------------------------------------------------------------------------
# Basic MHA / MQA / GQA
# ---------------------------------------------------------------------------
def test_mha_basic():
    """MHA: as many KV heads as query heads; output must match the reference (cos >= 0.99)."""
    torch.manual_seed(42)
    hd, T, N, n_q, n_kv = 64, 128, 256, 4, 4
    tensor_opts = {"dtype": torch.bfloat16, "device": 'cuda'}
    # q, k, v are drawn in this order from the seeded generator.
    query = torch.randn(n_q, T, hd, **tensor_opts)
    key = torch.randn(n_kv, N, hd, **tensor_opts)
    value = torch.randn(n_kv, N, hd, **tensor_opts)

    softmax_scale = 1.0 / math.sqrt(hd)
    actual = dsv4_attention(query, key, value)
    expected = _pytorch_ref_attention(query, key, value, softmax_scale)

    # Treat both outputs as single-row vectors and compare their direction.
    cos = torch.nn.functional.cosine_similarity(
        actual.reshape(1, -1), expected.float().reshape(1, -1)
    ).item()
    print(f" MHA n_q={n_q} n_kv={n_kv} N={N}: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"MHA cos={cos}"
def test_mqa_basic():
    """MQA: several query heads share one K/V, passed to the kernel as 2-D (N, hd) tensors."""
    torch.manual_seed(42)
    hd, T, N, n_q, n_kv = 64, 128, 256, 8, 1
    tensor_opts = {"dtype": torch.bfloat16, "device": 'cuda'}
    query = torch.randn(n_q, T, hd, **tensor_opts)
    shared_key = torch.randn(N, hd, **tensor_opts)  # 2D
    shared_value = torch.randn(N, hd, **tensor_opts)

    actual = dsv4_attention(query, shared_key, shared_value)
    # The reference is given the explicit single-head (1, N, hd) layout.
    expected = _pytorch_ref_attention(
        query,
        shared_key.unsqueeze(0),
        shared_value.unsqueeze(0),
        1.0 / math.sqrt(hd),
    )

    cos = torch.nn.functional.cosine_similarity(
        actual.reshape(1, -1), expected.float().reshape(1, -1)
    ).item()
    print(f" MQA n_q={n_q} n_kv=1 N={N}: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"MQA cos={cos}"
def test_gqa_basic():
    """GQA: 8 query heads grouped over 2 KV heads; output must match the reference (cos >= 0.99)."""
    torch.manual_seed(42)
    hd, T, N, n_q, n_kv = 64, 128, 256, 8, 2
    tensor_opts = {"dtype": torch.bfloat16, "device": 'cuda'}
    query = torch.randn(n_q, T, hd, **tensor_opts)
    key = torch.randn(n_kv, N, hd, **tensor_opts)
    value = torch.randn(n_kv, N, hd, **tensor_opts)

    actual = dsv4_attention(query, key, value)
    expected = _pytorch_ref_attention(query, key, value, 1.0 / math.sqrt(hd))

    flat_actual = actual.reshape(1, -1)
    flat_expected = expected.float().reshape(1, -1)
    cos = torch.nn.functional.cosine_similarity(flat_actual, flat_expected).item()
    print(f" GQA n_q={n_q} n_kv={n_kv} N={N}: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"GQA cos={cos}"
# ---------------------------------------------------------------------------
# Head-packed vs per-head parity
# ---------------------------------------------------------------------------
def test_head_packed_vs_per_head():
    """Compare dsv4_attention against dsv4_attention_per_head on the same MQA input (no sink bias).

    Passes when the cosine similarity of the two outputs is at least 0.999;
    the maximum absolute difference is printed for information only.
    """
    torch.manual_seed(42)
    hd, T, N, n_q, n_kv = 64, 128, 256, 4, 1
    tensor_opts = {"dtype": torch.bfloat16, "device": 'cuda'}
    query = torch.randn(n_q, T, hd, **tensor_opts)
    key = torch.randn(N, hd, **tensor_opts)
    value = torch.randn(N, hd, **tensor_opts)

    out_packed = dsv4_attention(query, key, value)
    out_per_head = dsv4_attention_per_head(query, key, value)

    cos = torch.nn.functional.cosine_similarity(
        out_packed.reshape(1, -1), out_per_head.float().reshape(1, -1)
    ).item()
    abs_delta = (out_packed.float() - out_per_head.float()).abs()
    max_diff = abs_delta.max().item()
    print(f" Packed vs per-head: cos {cos:.6f} max_diff {max_diff:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}")
    assert cos >= 0.999, f"Packed vs per-head cos={cos}"
# ---------------------------------------------------------------------------
# Multi-segment KV (Python KV merge)
# ---------------------------------------------------------------------------
def test_multi_segment_kv():
    """KV length N=512 (> 128), intended to exercise the multi-segment KV merge path."""
    torch.manual_seed(42)
    hd, T, N, n_q = 64, 128, 512, 2
    tensor_opts = {"dtype": torch.bfloat16, "device": 'cuda'}
    query = torch.randn(n_q, T, hd, **tensor_opts)
    # K/V carry one head per query head here (leading dim n_q).
    key = torch.randn(n_q, N, hd, **tensor_opts)
    value = torch.randn(n_q, N, hd, **tensor_opts)

    actual = dsv4_attention(query, key, value)
    expected = _pytorch_ref_attention(query, key, value, 1.0 / math.sqrt(hd))

    cos = torch.nn.functional.cosine_similarity(
        actual.reshape(1, -1), expected.float().reshape(1, -1)
    ).item()
    print(f" Multi-seg N={N}: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"Multi-seg cos={cos}"
# ---------------------------------------------------------------------------
# SWA + causal + sink bias
# ---------------------------------------------------------------------------
def test_swa_causal_sink():
    """SWA window, causal mask and sink bias enabled together (D3+D4+D5c combined)."""
    torch.manual_seed(42)
    hd, T, N, n_q = 64, 64, 256, 1
    tensor_opts = {"dtype": torch.bfloat16, "device": 'cuda'}
    query = torch.randn(n_q, T, hd, **tensor_opts)
    key = torch.randn(N, hd, **tensor_opts)  # MQA
    value = torch.randn(N, hd, **tensor_opts)

    # The same masking configuration is handed to both the kernel and the oracle.
    mask_config = {
        "swa_len": 128,
        "is_causal": True,
        "n_comp": 64,  # first 64 positions are compressed
        "sink_bias": torch.tensor([0.5], dtype=torch.float32, device='cuda'),
    }

    actual = dsv4_attention(query, key, value, **mask_config)
    expected = _pytorch_ref_attention(
        query, key.unsqueeze(0), value.unsqueeze(0), 1.0 / math.sqrt(hd),
        **mask_config,
    )

    cos = torch.nn.functional.cosine_similarity(
        actual.reshape(1, -1), expected.float().reshape(1, -1)
    ).item()
    print(f" SWA+causal+sink: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"SWA+causal+sink cos={cos}"
# ---------------------------------------------------------------------------
# Batch dimension
# ---------------------------------------------------------------------------
def test_batch_dimension():
    """4-D (batch, n_q, T, hd) input: output keeps q's shape and each batch item matches the reference."""
    torch.manual_seed(42)
    hd, T, N, n_q, batch = 64, 128, 128, 2, 2
    tensor_opts = {"dtype": torch.bfloat16, "device": 'cuda'}
    q = torch.randn(batch, n_q, T, hd, **tensor_opts)
    key = torch.randn(batch, n_q, N, hd, **tensor_opts)
    value = torch.randn(batch, n_q, N, hd, **tensor_opts)

    out = dsv4_attention(q, key, value)
    assert out.shape == q.shape, f"Shape mismatch: {out.shape} vs {q.shape}"

    # The reference only handles 3-D inputs, so check one batch item at a time.
    softmax_scale = 1.0 / math.sqrt(hd)
    for b in range(batch):
        expected = _pytorch_ref_attention(q[b], key[b], value[b], softmax_scale)
        cos = torch.nn.functional.cosine_similarity(
            out[b].reshape(1, -1), expected.float().reshape(1, -1)
        ).item()
        print(f" Batch[{b}]: cos {cos:.6f}")
        assert cos >= 0.99, f"Batch[{b}] cos={cos}"
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
def test_single_token():
    """Decode-style edge case: a single query token (T=1) against N=128 keys."""
    torch.manual_seed(42)
    hd, T, N, n_q = 64, 1, 128, 1
    tensor_opts = {"dtype": torch.bfloat16, "device": 'cuda'}
    query = torch.randn(n_q, T, hd, **tensor_opts)
    key = torch.randn(N, hd, **tensor_opts)
    value = torch.randn(N, hd, **tensor_opts)

    actual = dsv4_attention(query, key, value)
    expected = _pytorch_ref_attention(
        query, key.unsqueeze(0), value.unsqueeze(0), 1.0 / math.sqrt(hd)
    )

    cos = torch.nn.functional.cosine_similarity(
        actual.reshape(1, -1), expected.float().reshape(1, -1)
    ).item()
    print(f" Single token T=1: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"Single token cos={cos}"
def test_exact_fit_segment():
    """Edge case where the KV length is exactly 128 (one full segment, no padding)."""
    torch.manual_seed(42)
    hd, T, N, n_q = 64, 128, 128, 1
    tensor_opts = {"dtype": torch.bfloat16, "device": 'cuda'}
    query = torch.randn(n_q, T, hd, **tensor_opts)
    key = torch.randn(N, hd, **tensor_opts)
    value = torch.randn(N, hd, **tensor_opts)

    softmax_scale = 1.0 / math.sqrt(hd)
    actual = dsv4_attention(query, key, value)
    expected = _pytorch_ref_attention(
        query, key.unsqueeze(0), value.unsqueeze(0), softmax_scale
    )

    cos = torch.nn.functional.cosine_similarity(
        actual.reshape(1, -1), expected.float().reshape(1, -1)
    ).item()
    print(f" Exact fit N=128: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"Exact fit cos={cos}"
def test_partial_segment():
    """Edge case N=200: not a multiple of 128, so the last segment is only partly filled."""
    torch.manual_seed(42)
    hd, T, N, n_q = 64, 128, 200, 1
    tensor_opts = {"dtype": torch.bfloat16, "device": 'cuda'}
    query = torch.randn(n_q, T, hd, **tensor_opts)
    key = torch.randn(N, hd, **tensor_opts)
    value = torch.randn(N, hd, **tensor_opts)

    actual = dsv4_attention(query, key, value)
    expected = _pytorch_ref_attention(
        query, key.unsqueeze(0), value.unsqueeze(0), 1.0 / math.sqrt(hd)
    )

    flat_actual = actual.reshape(1, -1)
    flat_expected = expected.float().reshape(1, -1)
    cos = torch.nn.functional.cosine_similarity(flat_actual, flat_expected).item()
    print(f" Partial seg N=200: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"Partial segment cos={cos}"
# ---------------------------------------------------------------------------
# Custom op
# ---------------------------------------------------------------------------
def test_custom_op():
    """Call dsv4_sparse_fmha from dsv4.ops.custom_ops and compare it with the unmasked reference."""
    # Imported inside the test so only this test depends on the custom-op module.
    from dsv4.ops.custom_ops import dsv4_sparse_fmha

    torch.manual_seed(42)
    hd, T, N, n_q = 64, 128, 128, 1
    tensor_opts = {"dtype": torch.bfloat16, "device": 'cuda'}
    query = torch.randn(n_q, T, hd, **tensor_opts)
    key = torch.randn(N, hd, **tensor_opts)
    value = torch.randn(N, hd, **tensor_opts)
    zero_bias = torch.zeros(n_q, dtype=torch.float32, device='cuda')
    softmax_scale = 1.0 / math.sqrt(hd)

    # Arguments are passed positionally; the trailing 0, False, 0 are compared
    # against a reference call with no masking options.
    actual = dsv4_sparse_fmha(query, key, value, zero_bias, softmax_scale, 0, False, 0)
    expected = _pytorch_ref_attention(
        query, key.unsqueeze(0), value.unsqueeze(0), softmax_scale
    )

    cos = torch.nn.functional.cosine_similarity(
        actual.reshape(1, -1), expected.float().reshape(1, -1)
    ).item()
    print(f" Custom op: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"Custom op cos={cos}"
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def test():
    """Run every test in this file in a fixed order, printing a header before each group.

    Any failing assertion propagates, so the final "ALL TESTS PASSED" banner
    is only printed when every test returned normally.

    NOTE(review): the bare name `test` matches pytest's default `test*`
    function pattern, so pytest likely collects this runner as well — confirm
    that is intended.
    """
    rule = "=" * 60
    # (section heading, tests to run under it), in execution order.
    sections = (
        ("\n--- Basic MHA / MQA / GQA ---",
         (test_mha_basic, test_mqa_basic, test_gqa_basic)),
        ("\n--- Head-packed vs per-head parity ---",
         (test_head_packed_vs_per_head,)),
        ("\n--- Multi-segment KV (Python KV merge) ---",
         (test_multi_segment_kv,)),
        ("\n--- SWA + causal + sink bias ---",
         (test_swa_causal_sink,)),
        ("\n--- Batch dimension ---",
         (test_batch_dimension,)),
        ("\n--- Edge cases ---",
         (test_single_token, test_exact_fit_segment, test_partial_segment)),
        ("\n--- Custom op ---",
         (test_custom_op,)),
    )

    print(rule)
    print("Stage E: Production DSV4 Attention — Comprehensive Tests")
    print(rule)
    for heading, cases in sections:
        print(heading)
        for run_case in cases:
            run_case()
    print("\n" + rule)
    print("ALL TESTS PASSED")
    print(rule)


if __name__ == '__main__':
    test()