From aa2df1a202cd1df0e7f6e013f12b6c8c6f9ae7de Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 26 May 2026 15:33:38 +0000 Subject: [PATCH] diag: test n_comp=96 with sink bias directly --- tests/unit/test_d5c_diag_ncomp96.py | 64 +++++++++++++++++++++++++++++ tests/unit/test_d5c_multitile.py | 2 +- 2 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_d5c_diag_ncomp96.py diff --git a/tests/unit/test_d5c_diag_ncomp96.py b/tests/unit/test_d5c_diag_ncomp96.py new file mode 100644 index 00000000..1870d056 --- /dev/null +++ b/tests/unit/test_d5c_diag_ncomp96.py @@ -0,0 +1,64 @@ +""" +Quick diagnostic: test FMHA with n_comp=96 and sink_bias directly. +""" +import torch +import math +import cutlass.cute as cute +import cutlass.torch as ct +import cuda.bindings.driver as cuda +from dsv4.kernels.attention.fmha import FmhaKernel + +hd = 64; m = 128; n_comp = 96; n_swa = 32; n_total = 128 +swa_len = 32 # all SWA valid +scale = 1.0 / math.sqrt(hd) +torch.manual_seed(42) + +q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') +k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda') +v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda') +k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda') +v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda') +attn_sink = torch.tensor([0.5], dtype=torch.float32, device='cuda') + +k_combined = torch.cat([k_comp, k_swa], dim=0) +v_combined = torch.cat([v_comp, v_swa], dim=0) + +# Reference +qf = q[:, :, 0].float() +kf = k_combined[:, :, 0].float() +vf = v_combined.float() +scores = qf @ kf.T * scale +scores[:, n_comp:] += 0.5 # sink bias on SWA +ref = (torch.softmax(scores, dim=-1) @ vf).to(torch.bfloat16) +lse_ref = torch.logsumexp(scores, dim=-1) + +# Kernel +stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) +kernel = FmhaKernel(head_dim=hd, s_k=n_total, normalize=False, + apply_swa_mask=False, is_causal=False, + n_comp=n_comp, apply_sink_bias=True) + +c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') +lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') +rs_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + +def to_cute(t): return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t)) +mQ = to_cute(q); mK = to_cute(k_combined) +mV = to_cute(v_combined.unsqueeze(-1).contiguous()) +mC = to_cute(c_out); mLSE = to_cute(lse_out); mRS = to_cute(rs_out) +mSB = to_cute(attn_sink) + +compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE, + swa_len=swa_len, sink_bias=mSB, row_sums=mRS) +compiled(mQ, mK, mV, mC, stream, mLSE, + swa_len=swa_len, sink_bias=mSB, row_sums=mRS) +torch.cuda.synchronize() + +o_k = c_out[:, :, 0].float() / rs_out[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30) +cos = torch.nn.functional.cosine_similarity(o_k.flatten().unsqueeze(0), ref.flatten().unsqueeze(0).float()).item() +lse_k = lse_out[:, 0, 0].float() +print(f'n_comp=96, sink_bias=0.5: cos {cos:.6f} {("PASS" if cos > 0.99 else "FAIL")}') +print(f'LSE: kernel[0]={lse_k[0].item():.4f} ref[0]={lse_ref[0].item():.4f} diff={abs(lse_k[0].item()-lse_ref[0].item()):.6f}') +if cos < 0.99: + print(f' kernel[0,:4]={o_k[0,:4].tolist()}') + print(f' ref[0,:4]={ref[0,:4].tolist()}') diff --git a/tests/unit/test_d5c_multitile.py b/tests/unit/test_d5c_multitile.py index 691923d7..b0a39544 100644 --- a/tests/unit/test_d5c_multitile.py +++ b/tests/unit/test_d5c_multitile.py @@ -134,7 +134,7 @@ def test_d5c_multitile(): k_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda') v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda') - attn_sink_val = 0.0 # Start with no sink bias to isolate issues + attn_sink_val = 0.5 attn_sink = torch.tensor([attn_sink_val], dtype=torch.float32, device='cuda') # Reference