Files
nvfp4-megamoe-kernel/tests/test_diag_multitile.py

31 lines
1.7 KiB
Python

"""Test the identity diag (FmhaV3Diag) for n=128 and the multi-tile cases n=256,384."""
import torch, math, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline, cutlass.torch as ct, cuda.bindings.driver as cuda
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean
from cutlass.utils import LayoutEnum
from test_fmha_v3_diag import FmhaV3Diag
HEAD_DIM = 64
# Minimum cosine similarity between kernel output and the torch reference to count as PASS.
COS_THRESHOLD = 0.99
# KV lengths (s_k) to exercise; 256 and 384 are the multi-tile cases named in the module docstring.
SEQ_LENS = (128, 256, 384)


def _run_identity_diag(n):
    """Run FmhaV3Diag with s_k=n once and return cosine similarity vs. a torch reference.

    Builds fixed-seed bfloat16 CUDA inputs: Q of shape (128, HEAD_DIM, 1), K of shape
    (n, HEAD_DIM, 1), and an all-ones V of shape (n, HEAD_DIM) (passed to the kernel as
    (n, HEAD_DIM, 1)). The kernel writes into a zero-initialised C of shape (128, HEAD_DIM, 1).

    Args:
        n: KV sequence length, forwarded to ``FmhaV3Diag(s_k=n)``.

    Returns:
        Cosine similarity (Python float) between the flattened kernel output and reference.
    """
    torch.manual_seed(42)
    q = torch.randn(128, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.ones(n, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
    v_kernel = v.unsqueeze(-1)  # kernel is called with a trailing size-1 dim, like q/k/c
    c = torch.zeros(128, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')

    # Wrap torch tensors as CuTe tensors with dynamic layouts along each tensor's leading dim.
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    kernel = FmhaV3Diag(s_k=n)
    print(f'n={n}: Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()  # wait for the kernel before reading c on the host side

    out = c[:, :, 0].float()
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    # Reference applies no softmax: presumably the "identity" diag variant feeds the scaled
    # scores straight into P @ V — NOTE(review): confirm against FmhaV3Diag.
    # Because V is all ones, every output column equals the scaled row-sum of the scores.
    ref = (qf @ kf.T * (1.0 / math.sqrt(HEAD_DIM))) @ v.float()
    return torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
    ).item()


_failed = []
for n in SEQ_LENS:
    cos = _run_identity_diag(n)
    # A NaN similarity compares False here, so it is reported as FAIL.
    passed = cos >= COS_THRESHOLD
    print(f'Identity diag n={n}: cos {cos:.6f} {"PASS" if passed else "FAIL"}')
    if not passed:
        _failed.append(n)

if _failed:
    # Exit non-zero so shell/CI callers (and pytest collection) surface the failure.
    raise SystemExit(f'Identity diag FAILED for n={_failed}')