# docs/p7_tmem_column_layout.md: Verified that tcgen05.ld 32x32b.x8 is the
# correct instruction for multi-row softmax. Each call reads 8 KV positions for
# 32 rows; no instruction change is needed relative to the single-row case.
# test_p7_multi_row_softmax.py: Tests T = 1, 4, 32, 64, 128 at various HD and N.
# Gate: cos >= 0.999996.
"""P7 Integration Test: Multi-row softmax T>32.

Verifies the TMEM column layout finding: tcgen05.ld 32x32b.x8 is the correct
instruction for multi-row softmax. Each call reads 8 KV positions for 32 rows.

Gate: worst-case cosine >= 0.999996 per configuration for T in {1, 4, 32, 64, 128}.
"""

import math
import os
import sys

import torch

# Put the directory two levels above this file on sys.path so that the
# ``dsv4`` package can be imported when this file is run directly as a script
# (presumably the repository root -- confirm against the project layout).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Must come after the sys.path tweak above, hence the E402 suppression.
from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw  # noqa: E402


def reference_attention_prefill(q, k, v, scale):
    """PyTorch reference for prefill attention with grouped KV heads.

    Query head ``h`` attends using KV head ``h // (n_h // n_kv)``, i.e. each
    consecutive group of ``n_h // n_kv`` query heads shares one KV head.
    All arithmetic is done in float32; only the final result is rounded to BF16.

    Args:
        q: (n_h, T, hd) queries.
        k: (n_kv, N, hd) keys.
        v: (n_kv, hd, N) values in the kernel's transposed layout.
        scale: Multiplier applied to the Q.K^T logits before the softmax.

    Returns:
        (n_h, T, hd) BF16 tensor on the same device as the inputs.

    Raises:
        ValueError: If ``n_h`` is not a multiple of ``n_kv`` (or ``n_kv`` is 0).
    """
    n_h = q.shape[0]
    n_kv = k.shape[0]
    if n_kv == 0 or n_h % n_kv != 0:
        raise ValueError(
            f"number of query heads ({n_h}) must be a multiple of the "
            f"number of KV heads ({n_kv})"
        )
    q_per_kv = n_h // n_kv

    # repeat_interleave lays KV heads out as [kv0]*q_per_kv + [kv1]*q_per_kv + ...,
    # which is exactly the head h -> KV head h // q_per_kv mapping.
    k_exp = k.float().repeat_interleave(q_per_kv, dim=0)  # (n_h, N, hd)
    # Undo the kernel layout: (n_kv, hd, N) -> (n_kv, N, hd) before expanding.
    v_exp = v.float().transpose(-2, -1).repeat_interleave(q_per_kv, dim=0)  # (n_h, N, hd)

    scores = torch.matmul(q.float(), k_exp.transpose(-2, -1)) * scale  # (n_h, T, N)
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v_exp).to(torch.bfloat16)  # (n_h, T, hd)


# Minimum acceptable worst-case (over heads) cosine similarity per configuration.
_COS_GATE = 0.999996


def _worst_head_cosine(o_kernel, o_ref):
    """Return the smallest per-head cosine similarity between two (n_h, ...) outputs.

    Each head's output is flattened to a single vector before comparison. The
    reduction is done with a tensor ``min`` rather than Python's builtin ``min``,
    because the builtin silently discards NaN (``min(1.0, nan) == 1.0``).
    """
    cos = torch.nn.functional.cosine_similarity(
        o_kernel.float().flatten(start_dim=1),
        o_ref.float().flatten(start_dim=1),
        dim=1,
    )
    return cos.min().item()


def test_multi_row_softmax():
    """Test multi-row softmax for T in {1, 4, 32, 64, 128} at hd in {64, 128}.

    For each configuration, random BF16 Q/K/V are run through
    ``fmha_multitile_decode_raw`` and compared head-by-head against
    ``reference_attention_prefill``. A configuration passes when the kernel
    output has the reference shape, contains only finite values, and its worst
    per-head cosine similarity is >= ``_COS_GATE``.

    Returns:
        True when every configuration passes. This is kept so the script entry
        point's truthiness check continues to work.

    Raises:
        AssertionError: Listing the failing configurations. Raising is required
            for failures to be reported under pytest, which ignores return values.
    """
    torch.manual_seed(42)

    configs = [
        # (T, hd, N, desc)
        (1, 64, 256, "T=1 hd=64 (decode)"),
        (1, 128, 256, "T=1 hd=128 (decode)"),
        (4, 64, 256, "T=4 hd=64 (small prefill)"),
        (4, 128, 256, "T=4 hd=128"),
        (32, 64, 256, "T=32 hd=64 (1 warp)"),
        (32, 128, 256, "T=32 hd=128"),
        (64, 64, 256, "T=64 hd=64 (2 warps)"),
        (128, 64, 256, "T=128 hd=64 (4 warps)"),
        (128, 128, 256, "T=128 hd=128 (full tile)"),
        (128, 64, 512, "T=128 hd=64 N=512 (4 KV tiles)"),
    ]

    # One KV head per query head (n_h == n_kv) for every configuration.
    n_h = 4
    n_kv = 4

    failures = []
    for T, hd, N, desc in configs:
        scale = 1.0 / math.sqrt(hd)

        # q, k, v are drawn in this fixed order so the seed reproduces the inputs.
        q = torch.randn(1, n_h, T, hd, dtype=torch.bfloat16, device='cuda').contiguous()
        k = torch.randn(1, n_kv, N, hd, dtype=torch.bfloat16, device='cuda').contiguous()
        # V is passed as (..., hd, N), the layout reference_attention_prefill documents.
        v = torch.randn(1, n_kv, hd, N, dtype=torch.bfloat16, device='cuda').contiguous()

        o_4d, _ = fmha_multitile_decode_raw(q, k, v, scale)
        o_kernel = o_4d[0]  # expected (n_h, T, hd)
        o_ref = reference_attention_prefill(q[0], k[0], v[0], scale)

        if o_kernel.shape != o_ref.shape:
            failures.append(desc)
            print(f" FAIL {desc}: output shape {tuple(o_kernel.shape)} "
                  f"!= reference shape {tuple(o_ref.shape)}")
            continue

        finite = bool(torch.isfinite(o_kernel).all())
        worst_cos = _worst_head_cosine(o_kernel, o_ref)
        # A NaN cosine compares False here, so it fails rather than passes.
        passed = finite and worst_cos >= _COS_GATE
        if not passed:
            failures.append(desc)

        status = "PASS" if passed else "FAIL"
        note = "" if finite else " (non-finite kernel output)"
        # 8 decimals so values just under the 6-decimal gate are visibly below it.
        print(f" {status} {desc}: worst_cos={worst_cos:.8f}{note}")

    assert not failures, (
        f"{len(failures)} configuration(s) failed the cosine gate "
        f"{_COS_GATE}: {failures}"
    )
    return True


if __name__ == "__main__":
    # Banner printed before the per-configuration results.
    header_lines = (
        "P7 Integration Test: Multi-row softmax (T>32)",
        "=" * 60,
        "TMEM layout: 32x32b.x8 reads 8 KV positions for 32 rows per call",
        "",
    )
    for line in header_lines:
        print(line)

    # The test returns a truthy value only when every configuration passed;
    # anything else is reported and turned into a non-zero exit status.
    if not test_multi_row_softmax():
        print("\nSOME FAILED")
        sys.exit(1)
    print("\nALL PASS")