Files
nvfp4-megamoe-kernel/tests/unit/test_d2_regression.py
2026-05-25 17:00:56 +00:00

50 lines
1.7 KiB
Python

"""Quick test: n_h=1 regression after grid changes."""
import torch
import math
import cutlass.cute as cute
import cuda.bindings.driver as cuda
import cutlass.torch as ct
from dsv4.kernels.attention.fmha import FmhaKernel
def test():
    """Check FmhaKernel against a PyTorch softmax-attention reference for n_h=1.

    Runs the kernel on random bf16 inputs (M=128 query rows, s_k=128 key rows,
    head_dim=64, trailing axis of size 1) and compares its output against an
    fp32 reference computed with PyTorch, using cosine similarity over the
    flattened outputs.

    Prints a PASS/FAIL line for interactive runs. It also asserts, so the check
    actually fails under pytest instead of only printing a message.

    Raises:
        AssertionError: if the cosine similarity is below ``threshold``. This
            includes a NaN similarity, since every comparison with NaN is False.
    """
    print("=== n_h=1 regression (hd=64, s_k=128) ===")
    torch.manual_seed(42)
    M, s_k, hd = 128, 128, 64
    threshold = 0.99  # minimum acceptable cosine similarity vs. the reference
    scale = 1.0 / math.sqrt(hd)

    # All tensors carry a trailing axis of size 1.
    # NOTE(review): presumably this is the head axis (n_h=1); confirm against FmhaKernel.
    q = torch.randn(M, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    o = torch.zeros(M, hd, 1, dtype=torch.bfloat16, device='cuda')

    fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True)
    # Use PyTorch's current stream. Assuming FmhaKernel launches on the stream it
    # is given, the PyTorch reads of `o` below are stream-ordered after the kernel.
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    q_c = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    k_c = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    v_c = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
    o_c = ct.from_dlpack(o).mark_layout_dynamic(leading_dim=ct.get_leading_dim(o))
    fmha(q_c, k_c, v_c, o_c, stream)

    # Reference: softmax(q @ k^T * scale) @ v in fp32, on index 0 of the trailing
    # axis. It is then rounded to bf16 to match the kernel's output dtype.
    scores = torch.matmul(q[:, :, 0].float(), k[:, :, 0].float().T) * scale
    max_s = scores.max(dim=-1, keepdim=True).values  # subtract row max for numerical stability
    exp_s = (scores - max_s).exp()
    sum_s = exp_s.sum(dim=-1, keepdim=True)
    p = exp_s / sum_s
    ref = torch.matmul(p, v[:, :, 0].float()).to(torch.bfloat16)

    # A single similarity over all outputs; `.item()` also synchronizes the host.
    cos = torch.nn.functional.cosine_similarity(
        o[:, :, 0].flatten().float().unsqueeze(0),
        ref.flatten().float().unsqueeze(0),
    ).item()
    print(f" cos = {cos:.6f}")

    passed = cos >= threshold
    print(" ✅ PASS" if passed else " ❌ FAIL")
    # Without this assert, pytest would report success regardless of the result.
    assert passed, f"n_h=1 FMHA regression: cos={cos:.6f} < {threshold}"
if __name__ == "__main__":
    # Allow running this regression check directly as a script, outside pytest.
    test()