Files
nvfp4-megamoe-kernel/tests/unit/test_p7_multi_row_softmax.py
biondizzle e747742598 P7: Document TMEM column layout, add multi-row softmax test
docs/p7_tmem_column_layout.md: Verified that tcgen05.ld 32x32b.x8 is
the correct instruction for multi-row softmax. Each call reads 8 KV
positions for 32 rows. No instruction change needed from single-row.

test_p7_multi_row_softmax.py: Tests T=1,4,32,64,128 at various HD and N.
Gate: cos >= 0.999996.
2026-05-30 17:17:54 +00:00

101 lines
3.2 KiB
Python

"""
P7 Integration Test: Multi-row softmax T>32.
Verifies the TMEM column layout finding: tcgen05.ld 32x32b.x8 is the correct
instruction for multi-row softmax. Each call reads 8 KV positions for 32 rows.
Gate: worst-case cosine >= 0.999996 per configuration for T in {1, 4, 32, 64, 128}.
"""
import torch
import math
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw
def reference_attention_prefill(q, k, v, scale):
    """Compute reference scaled-dot-product attention with plain PyTorch.

    Query head ``h`` attends to KV head ``h // (n_h // n_kv)``. Scores, softmax
    and the weighted sum are evaluated in float32; each head's result is then
    rounded to bfloat16.

    Args:
        q: Query tensor of shape (n_h, T, hd).
        k: Key tensor of shape (n_kv, N, hd).
        v: Value tensor of shape (n_kv, hd, N) (kernel layout, transposed here).
        scale: Factor applied to the Q @ K^T scores before the softmax.

    Returns:
        Tensor of shape (n_h, T, hd), dtype bfloat16, on the same device as ``q``.

    Raises:
        ValueError: If ``n_kv`` is zero or ``n_h`` is not a multiple of ``n_kv``.
    """
    n_h, T, hd = q.shape
    n_kv = k.shape[0]
    if n_kv == 0 or n_h % n_kv != 0:
        # Without this check a bad grouping only shows up later as an
        # IndexError / ZeroDivisionError inside the head loop.
        raise ValueError(
            f"n_h ({n_h}) must be a multiple of n_kv ({n_kv})"
        )
    q_per_kv = n_h // n_kv

    # Allocate next to the inputs instead of assuming a CUDA device, so the
    # reference also works for CPU tensors.
    output = torch.zeros(n_h, T, hd, dtype=torch.bfloat16, device=q.device)
    for h in range(n_h):
        kv_idx = h // q_per_kv
        q_h = q[h]          # (T, hd)
        k_h = k[kv_idx]     # (N, hd)
        v_h = v[kv_idx].T   # (N, hd) -- transpose from kernel layout
        s = torch.matmul(q_h.float(), k_h.float().T) * scale  # (T, N)
        s = torch.softmax(s, dim=-1)
        o = torch.matmul(s, v_h.float())  # (T, hd)
        output[h] = o.bfloat16()
    return output
def test_multi_row_softmax():
    """Run the multi-row softmax gate for T in {1, 4, 32, 64, 128}.

    For each configuration, random BF16 Q/K/V tensors are created on the CUDA
    device, passed to ``fmha_multitile_decode_raw`` and compared head by head
    against ``reference_attention_prefill`` using cosine similarity. A
    configuration passes when its worst per-head cosine is >= 0.999996; a NaN
    cosine counts as a failure. One status line is printed per configuration.

    Returns:
        True if every configuration passes, False otherwise.

    NOTE(review): pytest ignores the return value of a test function, so when
    collected by pytest a failing configuration presumably does not fail the
    test -- confirm whether an assert-based wrapper is wanted.
    """
    torch.manual_seed(42)
    min_cos = 0.999996  # gate threshold, kept in one place
    n_h = 4
    n_kv = 4
    configs = [
        # (T, hd, N, desc)
        (1, 64, 256, "T=1 hd=64 (decode)"),
        (1, 128, 256, "T=1 hd=128 (decode)"),
        (4, 64, 256, "T=4 hd=64 (small prefill)"),
        (4, 128, 256, "T=4 hd=128"),
        (32, 64, 256, "T=32 hd=64 (1 warp)"),
        (32, 128, 256, "T=32 hd=128"),
        (64, 64, 256, "T=64 hd=64 (2 warps)"),
        (128, 64, 256, "T=128 hd=64 (4 warps)"),
        (128, 128, 256, "T=128 hd=128 (full tile)"),
        (128, 64, 512, "T=128 hd=64 N=512 (4 KV tiles)"),
    ]
    all_pass = True
    for T, hd, N, desc in configs:
        scale = 1.0 / math.sqrt(hd)
        q = torch.randn(1, n_h, T, hd, dtype=torch.bfloat16, device='cuda').contiguous()
        k = torch.randn(1, n_kv, N, hd, dtype=torch.bfloat16, device='cuda').contiguous()
        v = torch.randn(1, n_kv, hd, N, dtype=torch.bfloat16, device='cuda').contiguous()
        o_4d, _ = fmha_multitile_decode_raw(q, k, v, scale)
        o_kernel = o_4d[0]  # first batch entry, indexed per head below
        o_ref = reference_attention_prefill(q[0], k[0], v[0], scale)

        head_cos = [
            torch.nn.functional.cosine_similarity(
                o_kernel[h].flatten().float().unsqueeze(0),
                o_ref[h].flatten().float().unsqueeze(0),
            ).item()
            for h in range(n_h)
        ]
        if any(math.isnan(c) for c in head_cos):
            # min() would silently skip NaN (every comparison with NaN is
            # False), letting a NaN-producing kernel pass; propagate it.
            worst_cos = float("nan")
        else:
            worst_cos = min(1.0, *head_cos)

        passed = worst_cos >= min_cos  # False for NaN
        status = "PASS" if passed else "FAIL"
        if not passed:
            all_pass = False
        print(f" {status} {desc}: worst_cos={worst_cos:.6f}")
    return all_pass
if __name__ == "__main__":
    # Script entry point: print a banner, run the gate, and exit with status 1
    # when any configuration failed.
    banner = (
        "P7 Integration Test: Multi-row softmax (T>32)",
        "=" * 60,
        "TMEM layout: 32x32b.x8 reads 8 KV positions for 32 rows per call",
        "",
    )
    for banner_line in banner:
        print(banner_line)

    succeeded = test_multi_row_softmax()
    if not succeeded:
        print("\nSOME FAILED")
        sys.exit(1)
    print("\nALL PASS")