"""Test the FmhaV3Diag identity-diag kernel for single- and multi-tile s_k.

Runs ``FmhaV3Diag`` for n=128 (single K/V tile baseline) and the multi-tile
cases n=256 and n=384. The output is compared against the softmax-free
reference ``(Q @ K^T / sqrt(HEAD_DIM)) @ V``. This presumably matches the
kernel's "identity" diagnostic mode, in which softmax is replaced by the
identity (TODO confirm against FmhaV3Diag).

Prints one PASS/FAIL line per n and exits with status 1 if any case falls
below COS_THRESHOLD.
"""
import math
import sys

import torch
import cutlass  # noqa: F401
import cutlass.cute as cute
import cutlass.utils as utils  # noqa: F401
import cutlass.pipeline as pipeline  # noqa: F401
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from cutlass.cute.nvgpu import cpasync, tcgen05  # noqa: F401
from cutlass import Float32, BFloat16, Int32, Boolean  # noqa: F401
from cutlass.utils import LayoutEnum  # noqa: F401

from test_fmha_v3_diag import FmhaV3Diag

HEAD_DIM = 64
SEQ_Q = 128  # number of query rows (fixed for every case)
SEQ_K_VALUES = (128, 256, 384)  # key/value lengths passed to FmhaV3Diag as s_k
COS_THRESHOLD = 0.99  # minimum cosine similarity counted as PASS


def _to_cute(t):
    """Wrap torch tensor *t* as a CuTe tensor with a dynamic layout on its leading dim."""
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


def _run_case(n, stream):
    """Compile and run FmhaV3Diag with s_k=n.

    Returns the cosine similarity between the kernel output and the fp32
    reference as a Python float. The result is NaN-propagating, so a NaN
    output fails the ``>=`` threshold check.
    """
    torch.manual_seed(42)  # reseed per case so each n is reproducible on its own
    q = torch.randn(SEQ_Q, HEAD_DIM, 1, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device="cuda")
    # V is all ones, so every output column equals the row-sum of the scaled
    # scores Q @ K^T / sqrt(HEAD_DIM).
    v = torch.ones(n, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
    v_kernel = v.unsqueeze(-1)  # kernel takes a trailing size-1 dim like q/k/c
    c = torch.zeros(SEQ_Q, HEAD_DIM, 1, dtype=torch.bfloat16, device="cuda")

    mQ, mK, mV, mC = (_to_cute(t) for t in (q, k, v_kernel, c))

    kernel = FmhaV3Diag(s_k=n)
    print(f'n={n}: Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()  # wait for the kernel before reading c on the host side

    out = c[:, :, 0].float()
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    ref = (qf @ kf.T * (1.0 / math.sqrt(HEAD_DIM))) @ v.float()
    return torch.nn.functional.cosine_similarity(
        out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
    ).item()


# The current torch stream does not change between cases, so build it once.
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

failed = []
for n in SEQ_K_VALUES:
    cos = _run_case(n, stream)
    passed = cos >= COS_THRESHOLD
    print(f'Identity diag n={n}: cos {cos:.6f} {"PASS" if passed else "FAIL"}')
    if not passed:
        failed.append(n)

if failed:
    # Non-zero exit status so scripted/CI runs can detect the failure.
    sys.exit(f"Identity diag FAILED for n={failed}")