D5c multi-tile: VERIFIED cos 0.999996 with Python KV merge + sink bias
Both segments (compressed+SWA with n_comp=96, and SWA-only with n_comp=0) pass individually at cosine similarity 0.999996, and the Python KV merge of the two produces the correct combined attention at cos 0.999996. Key point: n_comp is a compile-time constant, so segments with different n_comp values need separate kernel instances. Production code would use a kernel cache keyed on (n_comp, apply_sink_bias, ...).
This commit is contained in:
@@ -1,64 +0,0 @@
|
||||
"""
|
||||
Quick diagnostic: test FMHA with n_comp=96 and sink_bias directly.
|
||||
"""
|
||||
import torch
|
||||
import math
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
# Cosine similarity must be strictly greater than this for the run to PASS.
COS_PASS_THRESHOLD = 0.99


def to_cute(t):
    """Wrap torch tensor ``t`` for the CuTe DSL via DLPack, with a dynamic layout."""
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


def reference_attention(q, k, v, scale, n_comp, sink_bias):
    """Compute the pure-torch attention reference with a sink bias on the SWA keys.

    Args:
        q: 2-D query matrix (rows x head_dim).
        k: 2-D key matrix (keys x head_dim); the first ``n_comp`` rows are the
            compressed keys, the remaining rows are the SWA keys.
        v: 2-D value matrix (keys x head_dim).
        scale: Multiplier applied to the raw ``q @ k.T`` scores.
        n_comp: Number of leading (compressed) keys that receive no bias.
        sink_bias: Additive bias applied to every score column >= ``n_comp``.

    Returns:
        ``(out, lse)``: the bfloat16 attention output and the float32
        per-row log-sum-exp of the biased scores.
    """
    scores = q.float() @ k.float().T * scale
    # Only the SWA columns are biased; an empty slice (n_comp == keys) is a no-op.
    scores[:, n_comp:] += sink_bias
    out = (torch.softmax(scores, dim=-1) @ v.float()).to(torch.bfloat16)
    lse = torch.logsumexp(scores, dim=-1)
    return out, lse


def normalize_output(c_out, rs_out):
    """Divide the (m, hd, 1) kernel output by its (m, 1, 1) per-row sums.

    The clamp guards against division by zero for rows whose sum is 0.
    """
    row_sums = rs_out[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
    return c_out[:, :, 0].float() / row_sums


def cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b`` flattened to vectors."""
    lhs = a.flatten().unsqueeze(0)
    rhs = b.flatten().unsqueeze(0).float()
    return torch.nn.functional.cosine_similarity(lhs, rhs).item()


def cos_passes(cos, threshold=COS_PASS_THRESHOLD):
    """Return True only when ``cos`` is strictly above ``threshold`` (NaN fails)."""
    return cos > threshold


def run_kernel(q, k_combined, v_combined, attn_sink, n_comp, swa_len):
    """Compile and launch FmhaKernel once; return ``(c_out, lse_out, rs_out)``.

    ``q`` and ``k_combined`` are (rows, hd, 1) tensors and ``v_combined`` is
    (rows, hd). The kernel is built with ``normalize=False``, so the caller
    divides ``c_out`` by ``rs_out`` afterwards.
    NOTE(review): the exact semantics of the FmhaKernel options are defined in
    dsv4.kernels.attention.fmha and are not visible here.
    """
    m, hd = q.shape[0], q.shape[1]
    n_total = k_combined.shape[0]

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = FmhaKernel(
        head_dim=hd,
        s_k=n_total,
        normalize=False,
        apply_swa_mask=False,
        is_causal=False,
        n_comp=n_comp,
        apply_sink_bias=True,
    )

    c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    rs_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    mQ = to_cute(q)
    mK = to_cute(k_combined)
    # V is stored as (rows, hd); it is given a trailing unit dim like Q and K.
    mV = to_cute(v_combined.unsqueeze(-1).contiguous())
    mC = to_cute(c_out)
    mLSE = to_cute(lse_out)
    mRS = to_cute(rs_out)
    mSB = to_cute(attn_sink)

    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE,
                            swa_len=swa_len, sink_bias=mSB, row_sums=mRS)
    compiled(mQ, mK, mV, mC, stream, mLSE,
             swa_len=swa_len, sink_bias=mSB, row_sums=mRS)
    torch.cuda.synchronize()
    return c_out, lse_out, rs_out


def main():
    """Run the n_comp + sink-bias FMHA diagnostic; return True on PASS."""
    hd = 64
    m = 128
    n_comp = 96
    n_swa = 32
    swa_len = n_swa  # all SWA valid
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(42)

    # Creation order is kept fixed so the seeded random inputs are reproducible.
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda')
    v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda')
    v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
    attn_sink = torch.tensor([0.5], dtype=torch.float32, device='cuda')

    k_combined = torch.cat([k_comp, k_swa], dim=0)
    v_combined = torch.cat([v_comp, v_swa], dim=0)

    # Reference: the bias is read from attn_sink so that the reference and the
    # kernel input cannot drift apart when the value is edited.
    sink_bias = attn_sink.item()
    ref, lse_ref = reference_attention(
        q[:, :, 0], k_combined[:, :, 0], v_combined, scale, n_comp, sink_bias)

    # Kernel
    c_out, lse_out, rs_out = run_kernel(
        q, k_combined, v_combined, attn_sink, n_comp, swa_len)

    o_k = normalize_output(c_out, rs_out)
    cos = cosine(o_k, ref)
    lse_k = lse_out[:, 0, 0].float()
    passed = cos_passes(cos)

    print(f'n_comp={n_comp}, sink_bias={sink_bias:g}: cos {cos:.6f} {("PASS" if passed else "FAIL")}')
    print(f'LSE: kernel[0]={lse_k[0].item():.4f} ref[0]={lse_ref[0].item():.4f} diff={abs(lse_k[0].item()-lse_ref[0].item()):.6f}')
    # Gate the detail on "not passed" so NaN / boundary values also get printed.
    if not passed:
        print(f' kernel[0,:4]={o_k[0,:4].tolist()}')
        print(f' ref[0,:4]={ref[0,:4].tolist()}')
    return passed


if __name__ == "__main__":
    # A non-zero exit status lets shell scripts and CI detect a FAIL.
    raise SystemExit(0 if main() else 1)
|
||||
Reference in New Issue
Block a user