""" P7 Integration Test: Multi-row softmax T>32. Verifies the TMEM column layout finding: tcgen05.ld 32x32b.x8 is the correct instruction for multi-row softmax. Each call reads 8 KV positions for 32 rows. Gate: worst-case cosine >= 0.999996 per configuration for T in {1, 4, 32, 128}. """ import torch import math import sys import os sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw def reference_attention_prefill(q, k, v, scale): """PyTorch reference for prefill attention. q: (n_h, T, hd), k: (n_kv, N, hd), v: (n_kv, hd, N) (kernel layout) Returns: (n_h, T, hd) BF16 """ n_h, T, hd = q.shape n_kv = k.shape[0] N = k.shape[1] q_per_kv = n_h // n_kv output = torch.zeros(n_h, T, hd, dtype=torch.bfloat16, device='cuda') for h in range(n_h): kv_idx = h // q_per_kv q_h = q[h] # (T, hd) k_h = k[kv_idx] # (N, hd) v_h = v[kv_idx].T # (N, hd) — transpose from kernel layout s = torch.matmul(q_h.float(), k_h.float().T) * scale # (T, N) s = torch.softmax(s, dim=-1) o = torch.matmul(s, v_h.float()) # (T, hd) output[h] = o.bfloat16() return output def test_multi_row_softmax(): """Test multi-row softmax for T in {1, 4, 32, 128} at various HD.""" torch.manual_seed(42) configs = [ # (T, hd, N, desc) (1, 64, 256, "T=1 hd=64 (decode)"), (1, 128, 256, "T=1 hd=128 (decode)"), (4, 64, 256, "T=4 hd=64 (small prefill)"), (4, 128, 256, "T=4 hd=128"), (32, 64, 256, "T=32 hd=64 (1 warp)"), (32, 128, 256, "T=32 hd=128"), (64, 64, 256, "T=64 hd=64 (2 warps)"), (128, 64, 256, "T=128 hd=64 (4 warps)"), (128, 128, 256, "T=128 hd=128 (full tile)"), (128, 64, 512, "T=128 hd=64 N=512 (4 KV tiles)"), ] all_pass = True for T, hd, N, desc in configs: scale = 1.0 / math.sqrt(hd) n_h = 4 n_kv = 4 q = torch.randn(1, n_h, T, hd, dtype=torch.bfloat16, device='cuda').contiguous() k = torch.randn(1, n_kv, N, hd, dtype=torch.bfloat16, device='cuda').contiguous() v = torch.randn(1, n_kv, hd, N, dtype=torch.bfloat16, device='cuda').contiguous() o_4d, _ = fmha_multitile_decode_raw(q, k, v, scale) o_kernel = o_4d[0] # (n_h, T, hd) o_ref = reference_attention_prefill(q[0], k[0], v[0], scale) worst_cos = 1.0 for h in range(n_h): cos = torch.nn.functional.cosine_similarity( o_kernel[h].flatten().float().unsqueeze(0), o_ref[h].flatten().float().unsqueeze(0) ).item() worst_cos = min(worst_cos, cos) status = "PASS" if worst_cos >= 0.999996 else "FAIL" if worst_cos < 0.999996: all_pass = False print(f" {status} {desc}: worst_cos={worst_cos:.6f}") return all_pass if __name__ == "__main__": print("P7 Integration Test: Multi-row softmax (T>32)") print("=" * 60) print("TMEM layout: 32x32b.x8 reads 8 KV positions for 32 rows per call") print() if test_multi_row_softmax(): print("\nALL PASS") else: print("\nSOME FAILED") sys.exit(1)