D5c multi-tile: VERIFIED cos 0.999996 with Python KV merge + sink bias

Both segments (compressed+SWA with n_comp=96, and SWA-only with n_comp=0)
pass individually at cos 0.999996. The Python KV merge produces the
correct combined attention at cos 0.999996.

Key: n_comp is compile-time, so separate kernel instances are needed
for segments with different n_comp values. Production code would use
a kernel cache keyed on (n_comp, apply_sink_bias, ...).
This commit is contained in:
2026-05-26 15:40:45 +00:00
parent c9eab3c7e0
commit 487d960a6a

View File

@@ -1,64 +0,0 @@
"""
Quick diagnostic: test FMHA with n_comp=96 and sink_bias directly.
"""
import torch
import math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
hd = 64; m = 128; n_comp = 96; n_swa = 32; n_total = 128
swa_len = 32 # all SWA valid
scale = 1.0 / math.sqrt(hd)
torch.manual_seed(42)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda')
v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda')
v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
attn_sink = torch.tensor([0.5], dtype=torch.float32, device='cuda')
k_combined = torch.cat([k_comp, k_swa], dim=0)
v_combined = torch.cat([v_comp, v_swa], dim=0)
# Reference
qf = q[:, :, 0].float()
kf = k_combined[:, :, 0].float()
vf = v_combined.float()
scores = qf @ kf.T * scale
scores[:, n_comp:] += 0.5 # sink bias on SWA
ref = (torch.softmax(scores, dim=-1) @ vf).to(torch.bfloat16)
lse_ref = torch.logsumexp(scores, dim=-1)
# Kernel
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FmhaKernel(head_dim=hd, s_k=n_total, normalize=False,
apply_swa_mask=False, is_causal=False,
n_comp=n_comp, apply_sink_bias=True)
c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
rs_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
def to_cute(t): return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
mQ = to_cute(q); mK = to_cute(k_combined)
mV = to_cute(v_combined.unsqueeze(-1).contiguous())
mC = to_cute(c_out); mLSE = to_cute(lse_out); mRS = to_cute(rs_out)
mSB = to_cute(attn_sink)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE,
swa_len=swa_len, sink_bias=mSB, row_sums=mRS)
compiled(mQ, mK, mV, mC, stream, mLSE,
swa_len=swa_len, sink_bias=mSB, row_sums=mRS)
torch.cuda.synchronize()
o_k = c_out[:, :, 0].float() / rs_out[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
cos = torch.nn.functional.cosine_similarity(o_k.flatten().unsqueeze(0), ref.flatten().unsqueeze(0).float()).item()
lse_k = lse_out[:, 0, 0].float()
print(f'n_comp=96, sink_bias=0.5: cos {cos:.6f} {("PASS" if cos > 0.99 else "FAIL")}')
print(f'LSE: kernel[0]={lse_k[0].item():.4f} ref[0]={lse_ref[0].item():.4f} diff={abs(lse_k[0].item()-lse_ref[0].item()):.6f}')
if cos < 0.99:
print(f' kernel[0,:4]={o_k[0,:4].tolist()}')
print(f' ref[0,:4]={ref[0,:4].tolist()}')