diff --git a/tests/unit/test_d5c_diag_ncomp96.py b/tests/unit/test_d5c_diag_ncomp96.py deleted file mode 100644 index 1870d056..00000000 --- a/tests/unit/test_d5c_diag_ncomp96.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Quick diagnostic: test FMHA with n_comp=96 and sink_bias directly. -""" -import torch -import math -import cutlass.cute as cute -import cutlass.torch as ct -import cuda.bindings.driver as cuda -from dsv4.kernels.attention.fmha import FmhaKernel - -hd = 64; m = 128; n_comp = 96; n_swa = 32; n_total = 128 -swa_len = 32 # all SWA valid -scale = 1.0 / math.sqrt(hd) -torch.manual_seed(42) - -q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') -k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda') -v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda') -k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda') -v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda') -attn_sink = torch.tensor([0.5], dtype=torch.float32, device='cuda') - -k_combined = torch.cat([k_comp, k_swa], dim=0) -v_combined = torch.cat([v_comp, v_swa], dim=0) - -# Reference -qf = q[:, :, 0].float() -kf = k_combined[:, :, 0].float() -vf = v_combined.float() -scores = qf @ kf.T * scale -scores[:, n_comp:] += 0.5 # sink bias on SWA -ref = (torch.softmax(scores, dim=-1) @ vf).to(torch.bfloat16) -lse_ref = torch.logsumexp(scores, dim=-1) - -# Kernel -stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) -kernel = FmhaKernel(head_dim=hd, s_k=n_total, normalize=False, - apply_swa_mask=False, is_causal=False, - n_comp=n_comp, apply_sink_bias=True) - -c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') -lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') -rs_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') - -def to_cute(t): return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t)) -mQ = to_cute(q); mK = to_cute(k_combined) -mV = to_cute(v_combined.unsqueeze(-1).contiguous()) -mC = to_cute(c_out); mLSE = to_cute(lse_out); mRS = to_cute(rs_out) -mSB = to_cute(attn_sink) - -compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE, - swa_len=swa_len, sink_bias=mSB, row_sums=mRS) -compiled(mQ, mK, mV, mC, stream, mLSE, - swa_len=swa_len, sink_bias=mSB, row_sums=mRS) -torch.cuda.synchronize() - -o_k = c_out[:, :, 0].float() / rs_out[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30) -cos = torch.nn.functional.cosine_similarity(o_k.flatten().unsqueeze(0), ref.flatten().unsqueeze(0).float()).item() -lse_k = lse_out[:, 0, 0].float() -print(f'n_comp=96, sink_bias=0.5: cos {cos:.6f} {("PASS" if cos > 0.99 else "FAIL")}') -print(f'LSE: kernel[0]={lse_k[0].item():.4f} ref[0]={lse_ref[0].item():.4f} diff={abs(lse_k[0].item()-lse_ref[0].item()):.6f}') -if cos < 0.99: - print(f' kernel[0,:4]={o_k[0,:4].tolist()}') - print(f' ref[0,:4]={ref[0,:4].tolist()}')