- Deleted fmha.py (CuTeDSL slow path), FmhaKernel, Python KV merge
- Deleted fmha_sm100.cuh, fmha_sm100_tc.cuh, fmha_sm100_launch.cu, fmha_epilogue_sm100.cuh
- Moved fmha_qk_verify.cuh to tests/unit/qk_verify_kernel.cuh
- Deleted decode_sparse.py, decode_swa.py, kernels/decode/
- Deleted 46 test_d*.py probes, test_smem_*, test_cotiled_*, test_tmem_*, test_smem_p_*, test_ultra_minimal, test_fmha_pv16, test_working_softmax_maybe
- Deleted root scratch: debug_linear.py, test_mapping.py, run_router_tests.py
- Moved archive/ to archived_plans/code_archive/
- Rewrote production.py: single fast path via 6-warp multi-tile kernel
- Added STATUS.md, audit_attention_live.md
- Moved NEXT_PRIORITIES*.md to archived_plans/
363 lines
13 KiB
Python
363 lines
13 KiB
Python
"""Comprehensive test suite for Stage E production attention.
|
|
|
|
Tests:
|
|
1. MHA / MQA / GQA correctness (head-packed)
|
|
2. Batch dimension support
|
|
3. Multi-segment KV (Python KV merge)
|
|
4. SWA masking + causal + sink bias
|
|
5. Per-head launch vs head-packed parity
|
|
6. Reference parity against FP32 oracle
|
|
7. Custom op registration
|
|
8. Edge cases: single token, single head, exact-fit segments
|
|
"""
|
|
import torch
|
|
import math
|
|
import pytest
|
|
from dsv4.kernels.attention.production import dsv4_attention, dsv4_attention_per_head # noqa: E402
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Reference implementations
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _pytorch_ref_attention(
    q: torch.Tensor,  # (n_q, T, hd)
    k: torch.Tensor,  # (n_kv, N, hd) or (N, hd)
    v: torch.Tensor,  # same as k
    scale: float,
    swa_len: int = None,
    is_causal: bool = False,
    n_comp: int = 0,
    sink_bias: torch.Tensor = None,  # (n_q,) or scalar
) -> torch.Tensor:
    """Full-precision PyTorch reference with SWA mask, causal, and sink bias.

    Scores are computed per query head in float32 as ``q @ k.T * scale``.
    Key positions ``>= n_comp`` form the "SWA region"; positions below
    ``n_comp`` are never biased or causally masked.

    Args:
        q: Queries of shape (n_q, T, hd).
        k: Keys of shape (n_kv, N, hd), or (N, hd) for a single shared KV head.
        v: Values, same layout as ``k``.
        scale: Multiplier applied to the raw QK^T scores.
        swa_len: If given, key positions ``>= n_comp + swa_len`` are masked out.
        is_causal: If true, query row ``t`` cannot see SWA-region keys whose
            offset inside the region (``pos - n_comp``) is greater than ``t``.
        n_comp: Number of leading key positions outside the SWA region.
        sink_bias: Optional additive bias for SWA-region scores. A tensor with
            more than one element is indexed per query head; a one-element or
            0-dim tensor is shared by all heads.

    Returns:
        Tensor of shape (n_q, T, hd) in bfloat16, on the same device as ``q``.

    Raises:
        ValueError: If ``n_q`` is not a positive multiple of ``n_kv``.
    """
    n_q, T, hd = q.shape
    if k.dim() == 2:
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
    n_kv, N, _ = k.shape
    if n_kv == 0 or n_q % n_kv != 0:
        # Previously surfaced as an opaque IndexError / ZeroDivisionError.
        raise ValueError(
            f"n_q ({n_q}) must be a positive multiple of n_kv ({n_kv})"
        )
    q_per_kv = n_q // n_kv

    # Follow the inputs' device instead of hard-coding 'cuda'.
    device = q.device

    # The masks depend only on (t, pos), so build them once for all heads.
    pos = torch.arange(N, device=device)  # (N,)
    in_swa = pos >= n_comp  # (N,)
    blocked = torch.zeros(T, N, dtype=torch.bool, device=device)
    if swa_len is not None:
        # SWA mask: positions >= n_comp + swa_len are invisible to every row.
        blocked |= pos >= n_comp + swa_len
    if is_causal:
        # Causal mask on the SWA region: offset within the region must be <= t.
        t_idx = torch.arange(T, device=device).unsqueeze(1)  # (T, 1)
        blocked |= in_swa & ((pos - n_comp) > t_idx)

    # Flatten so a 0-dim "scalar" bias can be indexed like a 1-element tensor.
    flat_bias = None if sink_bias is None else sink_bias.reshape(-1)

    ref = torch.zeros(n_q, T, hd, dtype=torch.float32, device=device)
    for qi in range(n_q):
        ki = qi // q_per_kv
        qf = q[qi].float()  # (T, hd)
        kf = k[ki].float()  # (N, hd)
        vf = v[ki].float()  # (N, hd)

        # QK^T
        attn = qf @ kf.T * scale  # (T, N)

        # Sink bias: add to SWA positions (>= n_comp) before masking.
        if flat_bias is not None:
            sb = float(flat_bias[qi] if flat_bias.numel() > 1 else flat_bias[0])
            attn = torch.where(in_swa, attn + sb, attn)

        attn = attn.masked_fill(blocked, float('-inf'))
        ref[qi] = torch.softmax(attn, dim=-1) @ vf

    return ref.bfloat16()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Basic MHA / MQA / GQA
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_mha_basic():
    """MHA parity: one KV head per query head (n_q == n_kv).

    Feeds seeded random bfloat16 tensors to ``dsv4_attention`` and requires a
    cosine similarity of at least 0.99 against the PyTorch reference.
    """
    torch.manual_seed(42)
    hd, T, N = 64, 128, 256
    n_q = n_kv = 4
    bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
    # Draw order q -> k -> v is fixed so the seeded values stay reproducible.
    q = torch.randn(n_q, T, hd, **bf16_cuda)
    k = torch.randn(n_kv, N, hd, **bf16_cuda)
    v = torch.randn(n_kv, N, hd, **bf16_cuda)

    softmax_scale = 1.0 / math.sqrt(hd)
    out = dsv4_attention(q, k, v)
    ref = _pytorch_ref_attention(q, k, v, softmax_scale)

    # Compare the two outputs as single flattened row vectors.
    out_row = out.flatten().unsqueeze(0)
    ref_row = ref.float().flatten().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(out_row, ref_row).item()
    print(f" MHA n_q={n_q} n_kv={n_kv} N={N}: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"MHA cos={cos}"
|
|
|
|
|
|
def test_mqa_basic():
    """MQA parity: several query heads share one K/V passed as 2D (N, hd).

    Requires a cosine similarity of at least 0.99 against the PyTorch
    reference, which receives K/V with an explicit leading head axis.
    """
    torch.manual_seed(42)
    hd, T, N = 64, 128, 256
    n_q, n_kv = 8, 1
    bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
    # Draw order q -> k -> v is fixed so the seeded values stay reproducible.
    q = torch.randn(n_q, T, hd, **bf16_cuda)
    k = torch.randn(N, hd, **bf16_cuda)  # 2D: no head axis
    v = torch.randn(N, hd, **bf16_cuda)

    softmax_scale = 1.0 / math.sqrt(hd)
    out = dsv4_attention(q, k, v)
    ref = _pytorch_ref_attention(q, k.unsqueeze(0), v.unsqueeze(0), softmax_scale)

    # Compare the two outputs as single flattened row vectors.
    out_row = out.flatten().unsqueeze(0)
    ref_row = ref.float().flatten().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(out_row, ref_row).item()
    print(f" MQA n_q={n_q} n_kv=1 N={N}: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"MQA cos={cos}"
|
|
|
|
|
|
def test_gqa_basic():
    """GQA parity: query heads are grouped over fewer KV heads (8 over 2).

    Requires a cosine similarity of at least 0.99 against the PyTorch
    reference.
    """
    torch.manual_seed(42)
    hd, T, N = 64, 128, 256
    n_q, n_kv = 8, 2
    bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
    # Draw order q -> k -> v is fixed so the seeded values stay reproducible.
    q = torch.randn(n_q, T, hd, **bf16_cuda)
    k = torch.randn(n_kv, N, hd, **bf16_cuda)
    v = torch.randn(n_kv, N, hd, **bf16_cuda)

    softmax_scale = 1.0 / math.sqrt(hd)
    out = dsv4_attention(q, k, v)
    ref = _pytorch_ref_attention(q, k, v, softmax_scale)

    # Compare the two outputs as single flattened row vectors.
    out_row = out.flatten().unsqueeze(0)
    ref_row = ref.float().flatten().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(out_row, ref_row).item()
    print(f" GQA n_q={n_q} n_kv={n_kv} N={N}: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"GQA cos={cos}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Head-packed vs per-head parity
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_head_packed_vs_per_head():
    """Parity between ``dsv4_attention`` and ``dsv4_attention_per_head``.

    Both entry points get the same seeded inputs (no sink bias). The flattened
    outputs must reach a cosine similarity of at least 0.999; the maximum
    absolute difference is printed for information only.
    """
    torch.manual_seed(42)
    hd, T, N = 64, 128, 256
    n_q, n_kv = 4, 1
    bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
    # Draw order q -> k -> v is fixed so the seeded values stay reproducible.
    q = torch.randn(n_q, T, hd, **bf16_cuda)
    k = torch.randn(N, hd, **bf16_cuda)
    v = torch.randn(N, hd, **bf16_cuda)

    out_packed = dsv4_attention(q, k, v)
    out_per_head = dsv4_attention_per_head(q, k, v)

    packed_row = out_packed.flatten().unsqueeze(0)
    per_head_row = out_per_head.float().flatten().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(packed_row, per_head_row).item()

    abs_gap = (out_packed.float() - out_per_head.float()).abs()
    max_diff = abs_gap.max().item()
    print(f" Packed vs per-head: cos {cos:.6f} max_diff {max_diff:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}")
    assert cos >= 0.999, f"Packed vs per-head cos={cos}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Multi-segment KV (Python KV merge)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_multi_segment_kv():
    """Long-KV parity: N=512 keys with one KV head per query head.

    NOTE(review): the original docstring says N > 128 triggers a Python KV
    merge across segments; that mechanism is not visible from this file, so
    confirm it against production.py.
    """
    torch.manual_seed(42)
    hd, T, N = 64, 128, 512
    n_q = 2
    bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
    # Draw order q -> k -> v is fixed so the seeded values stay reproducible.
    q = torch.randn(n_q, T, hd, **bf16_cuda)
    k = torch.randn(n_q, N, hd, **bf16_cuda)
    v = torch.randn(n_q, N, hd, **bf16_cuda)

    softmax_scale = 1.0 / math.sqrt(hd)
    out = dsv4_attention(q, k, v)
    ref = _pytorch_ref_attention(q, k, v, softmax_scale)

    # Compare the two outputs as single flattened row vectors.
    out_row = out.flatten().unsqueeze(0)
    ref_row = ref.float().flatten().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(out_row, ref_row).item()
    print(f" Multi-seg N={N}: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"Multi-seg cos={cos}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SWA + causal + sink bias
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_swa_causal_sink():
    """Combined SWA window, causal mask and sink bias on a single query head.

    The same masking options are passed to ``dsv4_attention`` and to the
    PyTorch reference; the outputs must reach a cosine similarity of at
    least 0.99. (Original label: D3+D4+D5c combined.)
    """
    torch.manual_seed(42)
    hd, T, N = 64, 64, 256
    n_q = 1
    bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
    # Draw order q -> k -> v is fixed so the seeded values stay reproducible.
    q = torch.randn(n_q, T, hd, **bf16_cuda)
    k = torch.randn(N, hd, **bf16_cuda)  # MQA: shared 2D K/V
    v = torch.randn(N, hd, **bf16_cuda)

    swa_len = 128
    n_comp = 64  # first 64 positions are compressed
    sink_bias = torch.tensor([0.5], dtype=torch.float32, device='cuda')

    # One options dict keeps the kernel call and the reference call in sync.
    mask_opts = {
        'swa_len': swa_len,
        'is_causal': True,
        'n_comp': n_comp,
        'sink_bias': sink_bias,
    }
    softmax_scale = 1.0 / math.sqrt(hd)
    out = dsv4_attention(q, k, v, **mask_opts)
    ref = _pytorch_ref_attention(
        q, k.unsqueeze(0), v.unsqueeze(0), softmax_scale, **mask_opts
    )

    out_row = out.flatten().unsqueeze(0)
    ref_row = ref.float().flatten().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(out_row, ref_row).item()
    print(f" SWA+causal+sink: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"SWA+causal+sink cos={cos}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Batch dimension
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_batch_dimension():
    """Batched call: 4D (batch, n_q, T, hd) inputs must give a same-shape output.

    After the shape check, each batch item is compared on its own against the
    PyTorch reference with a cosine-similarity floor of 0.99.
    """
    torch.manual_seed(42)
    hd, T, N = 64, 128, 128
    n_q, batch = 2, 2
    bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
    # Draw order q -> k -> v is fixed so the seeded values stay reproducible.
    q = torch.randn(batch, n_q, T, hd, **bf16_cuda)
    k = torch.randn(batch, n_q, N, hd, **bf16_cuda)
    v = torch.randn(batch, n_q, N, hd, **bf16_cuda)

    out = dsv4_attention(q, k, v)
    assert out.shape == q.shape, f"Shape mismatch: {out.shape} vs {q.shape}"

    softmax_scale = 1.0 / math.sqrt(hd)
    cosine = torch.nn.functional.cosine_similarity
    # The reference has no batch axis, so verify one batch item at a time.
    for b in range(batch):
        ref_b = _pytorch_ref_attention(q[b], k[b], v[b], softmax_scale)
        out_row = out[b].flatten().unsqueeze(0)
        ref_row = ref_b.float().flatten().unsqueeze(0)
        cos = cosine(out_row, ref_row).item()
        print(f" Batch[{b}]: cos {cos:.6f}")
        assert cos >= 0.99, f"Batch[{b}] cos={cos}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Edge cases
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_single_token():
    """Decode-style edge case: a single query token (T=1) against N=128 keys.

    Requires a cosine similarity of at least 0.99 against the PyTorch
    reference.
    """
    torch.manual_seed(42)
    hd, T, N = 64, 1, 128
    n_q = 1
    bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
    # Draw order q -> k -> v is fixed so the seeded values stay reproducible.
    q = torch.randn(n_q, T, hd, **bf16_cuda)
    k = torch.randn(N, hd, **bf16_cuda)
    v = torch.randn(N, hd, **bf16_cuda)

    softmax_scale = 1.0 / math.sqrt(hd)
    out = dsv4_attention(q, k, v)
    ref = _pytorch_ref_attention(q, k.unsqueeze(0), v.unsqueeze(0), softmax_scale)

    # Compare the two outputs as single flattened row vectors.
    out_row = out.flatten().unsqueeze(0)
    ref_row = ref.float().flatten().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(out_row, ref_row).item()
    print(f" Single token T=1: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"Single token cos={cos}"
|
|
|
|
|
|
def test_exact_fit_segment():
    """Edge case with N=128 keys and a single query head.

    NOTE(review): the original docstring describes N=128 as exactly one
    s_k=128 segment with no padding; the segment size is a kernel detail not
    visible from this file, so confirm it against production.py.
    """
    torch.manual_seed(42)
    hd, T, N = 64, 128, 128
    n_q = 1
    bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
    # Draw order q -> k -> v is fixed so the seeded values stay reproducible.
    q = torch.randn(n_q, T, hd, **bf16_cuda)
    k = torch.randn(N, hd, **bf16_cuda)
    v = torch.randn(N, hd, **bf16_cuda)

    softmax_scale = 1.0 / math.sqrt(hd)
    out = dsv4_attention(q, k, v)
    ref = _pytorch_ref_attention(q, k.unsqueeze(0), v.unsqueeze(0), softmax_scale)

    # Compare the two outputs as single flattened row vectors.
    out_row = out.flatten().unsqueeze(0)
    ref_row = ref.float().flatten().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(out_row, ref_row).item()
    print(f" Exact fit N=128: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"Exact fit cos={cos}"
|
|
|
|
|
|
def test_partial_segment():
    """Edge case with N=200 keys, a length that is not a multiple of 128.

    NOTE(review): the original docstring describes this as two segments with
    the second one partially padded; the padding behaviour is a kernel detail
    not visible from this file, so confirm it against production.py.
    """
    torch.manual_seed(42)
    hd, T, N = 64, 128, 200
    n_q = 1
    bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
    # Draw order q -> k -> v is fixed so the seeded values stay reproducible.
    q = torch.randn(n_q, T, hd, **bf16_cuda)
    k = torch.randn(N, hd, **bf16_cuda)
    v = torch.randn(N, hd, **bf16_cuda)

    softmax_scale = 1.0 / math.sqrt(hd)
    out = dsv4_attention(q, k, v)
    ref = _pytorch_ref_attention(q, k.unsqueeze(0), v.unsqueeze(0), softmax_scale)

    # Compare the two outputs as single flattened row vectors.
    out_row = out.flatten().unsqueeze(0)
    ref_row = ref.float().flatten().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(out_row, ref_row).item()
    print(f" Partial seg N=200: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"Partial segment cos={cos}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Custom op
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_custom_op():
    """Run ``dsv4_sparse_fmha`` from ``dsv4.ops.custom_ops`` and check parity.

    The op is imported inside the test, called with a zero sink bias, and its
    output must reach a cosine similarity of at least 0.99 against the
    unmasked PyTorch reference.
    """
    from dsv4.ops.custom_ops import dsv4_sparse_fmha

    torch.manual_seed(42)
    hd, T, N = 64, 128, 128
    n_q = 1
    bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
    # Draw order q -> k -> v is fixed so the seeded values stay reproducible.
    q = torch.randn(n_q, T, hd, **bf16_cuda)
    k = torch.randn(N, hd, **bf16_cuda)
    v = torch.randn(N, hd, **bf16_cuda)
    sink_bias = torch.zeros(n_q, dtype=torch.float32, device='cuda')

    softmax_scale = 1.0 / math.sqrt(hd)
    # NOTE(review): the trailing positional args (0, False, 0) presumably map
    # to swa_len / is_causal / n_comp -- confirm against the op's signature.
    out = dsv4_sparse_fmha(q, k, v, sink_bias, softmax_scale, 0, False, 0)
    ref = _pytorch_ref_attention(q, k.unsqueeze(0), v.unsqueeze(0), softmax_scale)

    out_row = out.flatten().unsqueeze(0)
    ref_row = ref.float().flatten().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(out_row, ref_row).item()
    print(f" Custom op: cos {cos:.6f} {'PASS' if cos >= 0.99 else 'FAIL'}")
    assert cos >= 0.99, f"Custom op cos={cos}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test():
    """Run every test in this module in a fixed order with section banners.

    Any failing assertion propagates, so "ALL TESTS PASSED" is printed only
    when every case succeeded.

    NOTE(review): the name ``test`` matches pytest's default ``test*``
    function pattern, so pytest probably collects this runner as well and
    executes the whole suite a second time -- confirm and consider renaming.
    """
    banner = "=" * 60
    print(banner)
    print("Stage E: Production DSV4 Attention — Comprehensive Tests")
    print(banner)

    # (section heading, cases) pairs, executed strictly in this order.
    sections = (
        ("\n--- Basic MHA / MQA / GQA ---",
         (test_mha_basic, test_mqa_basic, test_gqa_basic)),
        ("\n--- Head-packed vs per-head parity ---",
         (test_head_packed_vs_per_head,)),
        ("\n--- Multi-segment KV (Python KV merge) ---",
         (test_multi_segment_kv,)),
        ("\n--- SWA + causal + sink bias ---",
         (test_swa_causal_sink,)),
        ("\n--- Batch dimension ---",
         (test_batch_dimension,)),
        ("\n--- Edge cases ---",
         (test_single_token, test_exact_fit_segment, test_partial_segment)),
        ("\n--- Custom op ---",
         (test_custom_op,)),
    )
    for heading, cases in sections:
        print(heading)
        for run_case in cases:
            run_case()

    print("\n" + banner)
    print("ALL TESTS PASSED")
    print(banner)
|
|
|
|
|
|
# Script entry point: when the file is executed directly (rather than
# imported), run the full suite through the test() runner.
if __name__ == '__main__':
    test()
|