diag: test n_comp=96 with sink bias directly

This commit is contained in:
2026-05-26 15:33:38 +00:00
parent 25b236fe00
commit aa2df1a202
2 changed files with 65 additions and 1 deletions

View File

@@ -0,0 +1,64 @@
"""
Quick diagnostic: test FMHA with n_comp=96 and sink_bias directly.
"""
import torch
import math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
hd = 64; m = 128; n_comp = 96; n_swa = 32; n_total = 128
swa_len = 32 # all SWA valid
scale = 1.0 / math.sqrt(hd)
torch.manual_seed(42)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda')
v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda')
v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
attn_sink = torch.tensor([0.5], dtype=torch.float32, device='cuda')
k_combined = torch.cat([k_comp, k_swa], dim=0)
v_combined = torch.cat([v_comp, v_swa], dim=0)
# Reference
qf = q[:, :, 0].float()
kf = k_combined[:, :, 0].float()
vf = v_combined.float()
scores = qf @ kf.T * scale
scores[:, n_comp:] += 0.5 # sink bias on SWA
ref = (torch.softmax(scores, dim=-1) @ vf).to(torch.bfloat16)
lse_ref = torch.logsumexp(scores, dim=-1)
# Kernel
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FmhaKernel(head_dim=hd, s_k=n_total, normalize=False,
apply_swa_mask=False, is_causal=False,
n_comp=n_comp, apply_sink_bias=True)
c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
rs_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
def to_cute(t): return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
mQ = to_cute(q); mK = to_cute(k_combined)
mV = to_cute(v_combined.unsqueeze(-1).contiguous())
mC = to_cute(c_out); mLSE = to_cute(lse_out); mRS = to_cute(rs_out)
mSB = to_cute(attn_sink)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE,
swa_len=swa_len, sink_bias=mSB, row_sums=mRS)
compiled(mQ, mK, mV, mC, stream, mLSE,
swa_len=swa_len, sink_bias=mSB, row_sums=mRS)
torch.cuda.synchronize()
o_k = c_out[:, :, 0].float() / rs_out[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
cos = torch.nn.functional.cosine_similarity(o_k.flatten().unsqueeze(0), ref.flatten().unsqueeze(0).float()).item()
lse_k = lse_out[:, 0, 0].float()
print(f'n_comp=96, sink_bias=0.5: cos {cos:.6f} {("PASS" if cos > 0.99 else "FAIL")}')
print(f'LSE: kernel[0]={lse_k[0].item():.4f} ref[0]={lse_ref[0].item():.4f} diff={abs(lse_k[0].item()-lse_ref[0].item()):.6f}')
if cos < 0.99:
print(f' kernel[0,:4]={o_k[0,:4].tolist()}')
print(f' ref[0,:4]={ref[0,:4].tolist()}')

View File

@@ -134,7 +134,7 @@ def test_d5c_multitile():
k_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
attn_sink_val = 0.0 # Start with no sink bias to isolate issues
attn_sink_val = 0.5
attn_sink = torch.tensor([attn_sink_val], dtype=torch.float32, device='cuda')
# Reference