# File: nvfp4-megamoe-kernel/tests/unit/test_d2_regression.py
# (file viewer metadata: 62 lines, 2.3 KiB, Python)

"""
D2: n_h=1 regression test. Identical to D1 but with the D2 test infrastructure.
"""
import torch, math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def test():
    """Run the n_h=1 FMHA kernel once and check it against an FP32 reference.

    Random bfloat16 Q, K and V tensors are created on the CUDA device with a
    fixed seed. ``FmhaKernel`` is built with ``normalize=False`` and launched
    on the first ``pv_n_tile`` columns of V. Its output is compared, by
    cosine similarity, with ``exp(Q K^T * scale - rowmax) @ V`` computed in
    float32 by PyTorch and restricted to the same ``pv_n_tile`` columns.

    Raises:
        AssertionError: if the cosine similarity is below 0.99 (or NaN), so
            that the regression is reported as a failure by pytest and by a
            non-zero exit status when run as a script.
    """
    hd = 64
    s_k = 128
    m = 128
    torch.manual_seed(42)
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')

    # FP32 reference: unnormalized softmax numerator times V. The row max is
    # subtracted before exp, matching the comparison target of the original
    # test (the kernel is constructed with normalize=False).
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    scale = 1.0 / math.sqrt(hd)
    scores = qf @ kf.T * scale
    attn_max = scores.max(dim=-1, keepdim=True)[0]
    attn_exp = torch.exp(scores - attn_max)
    ref_unnorm = attn_exp @ v.float()

    lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False)
    pv_n_tile = kernel.pv_n_tile
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Only the first pv_n_tile columns of V are handed to the kernel; the
    # trailing unit dimension mirrors the layout used for q, k and c_tile.
    v_tile = v[:, 0:pv_n_tile].contiguous()
    v_kernel = v_tile.unsqueeze(-1)
    c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')

    # Wrap each torch tensor for the kernel via DLPack.
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
    mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
    mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))

    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
    compiled(mQ, mK, mV, mC, stream, mLSE)
    torch.cuda.synchronize()

    out = c_tile[:, :, 0].float()
    # The kernel only produced pv_n_tile output columns, so compare against
    # the matching columns of the reference; otherwise the flattened vectors
    # have different lengths whenever pv_n_tile != hd.
    ref = ref_unnorm[:, :pv_n_tile]
    cos = torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
    ).item()
    passed = cos >= 0.99  # NaN compares False and is therefore a failure
    print(f'hd={hd}, s_k={s_k}: cos_norm {cos:.6f} {"PASS" if passed else "FAIL"}')
    # Printing alone never fails the test; assert so regressions are caught.
    assert passed, f'cosine similarity {cos:.6f} is below the 0.99 threshold'
# Script entry point: run the regression check when executed directly.
if __name__ == "__main__":
    test()