# File-viewer metadata (not part of the program): 167 lines, 7.7 KiB, Python.
"""
FMHA D5c: Sink bias + Python KV merge for multi-tile (s_k > 128).

Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d5c_multitile.py
"""
|
|
import torch
|
|
import math
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as ct
|
|
import cuda.bindings.driver as cuda
|
|
from dsv4.kernels.attention.fmha import FmhaKernel
|
|
|
|
|
|
def to_cute(t):
    """Wrap a torch tensor as a CuTe tensor with a dynamic layout.

    The tensor is converted through DLPack, and the resulting layout is
    marked dynamic along the leading dimension that ``ct.get_leading_dim``
    reports for ``t``.
    """
    cute_tensor = ct.from_dlpack(t)
    leading = ct.get_leading_dim(t)
    return cute_tensor.mark_layout_dynamic(leading_dim=leading)
|
|
|
|
|
|
def reference_combined_attention(q, k_comp, v_comp, k_swa, v_swa,
                                 attn_sink, scale, swa_len, is_causal=False):
    """Reference softmax attention over compressed KV followed by SWA KV.

    The key/value sets are concatenated as ``[comp, swa]`` along dim 0 and a
    single softmax is taken over the combined scores, computed in float32.

    Args:
        q: 2-D query tensor ``(m, hd)``.
        k_comp, v_comp: compressed keys/values, ``n_comp`` rows each.
        k_swa, v_swa: sliding-window keys/values, ``n_swa`` rows each.
        attn_sink: bias added to every SWA score column (scalar or a tensor
            broadcastable against the SWA score slice).
        scale: multiplier applied to ``q @ k.T``.
        swa_len: number of leading SWA columns that stay visible; columns
            from ``swa_len`` onward are masked when ``swa_len < n_swa``.
        is_causal: when True, SWA column ``j`` is masked for query row ``i``
            whenever ``j > i``. Compressed columns are never causally masked.

    Returns:
        ``(m, v_dim)`` bfloat16 tensor. A row whose columns are all masked
        yields zeros rather than NaN.
    """
    m, _ = q.shape  # unpacking keeps the original "q must be 2-D" check
    n_comp = k_comp.shape[0]
    n_swa = k_swa.shape[0]

    k_comb = torch.cat([k_comp, k_swa], dim=0)
    v_comb = torch.cat([v_comp, v_swa], dim=0)

    scores = q.float() @ k_comb.float().T * scale
    scores[:, n_comp:] += attn_sink

    if swa_len < n_swa:
        scores[:, n_comp + swa_len:] = float('-inf')

    if is_causal:
        # One vectorised mask instead of m * n_swa scalar tensor writes.
        rows = torch.arange(m, device=scores.device).unsqueeze(1)
        cols = torch.arange(n_swa, device=scores.device).unsqueeze(0)
        scores[:, n_comp:].masked_fill_(cols > rows, float('-inf'))

    max_s = scores.max(dim=-1, keepdim=True).values
    # A fully masked row has max == -inf; subtracting it would give NaN
    # (-inf - -inf) before the clamp below can help. Shift by 0 instead so
    # the row's exponentials are all 0 and the output row is 0.
    max_s = max_s.masked_fill(max_s == float('-inf'), 0.0)
    exp_s = (scores - max_s).exp()
    sum_s = exp_s.sum(dim=-1, keepdim=True).clamp(min=1e-30)
    return (exp_s / sum_s @ v_comb.float()).to(torch.bfloat16)
|
|
|
|
|
|
def python_kv_merge(segment_results):
    """Merge per-segment normalized attention outputs using their LSEs.

    Computes ``O = sum_i(exp(lse_i) * O_i_norm) / sum_i(exp(lse_i))``, with
    every ``lse_i`` shifted by the per-row maximum before exponentiation.

    Args:
        segment_results: non-empty sequence of ``(o_norm, lse)`` pairs, where
            ``o_norm`` is ``(m, d)`` and ``lse`` is ``(m,)``.

    Returns:
        ``(m, d)`` float32 tensor. Rows whose LSE is ``-inf`` in every segment
        come out as zeros.
    """
    lse_stack = torch.stack([lse for _, lse in segment_results], dim=0)
    lse_max = lse_stack.max(dim=0).values
    # If every segment reports -inf for a row, (lse - lse_max) would be NaN.
    # Shifting by 0 makes all weights 0, so the row merges to 0.
    lse_max = lse_max.masked_fill(lse_max == float('-inf'), 0.0)

    # Accumulate in float32 whatever dtype the segment outputs use; taking
    # the accumulator dtype from a bf16 input would round every partial sum.
    numerator = torch.zeros_like(segment_results[0][0], dtype=torch.float32)
    denominator = torch.zeros_like(lse_max, dtype=torch.float32)

    for o_norm, lse in segment_results:
        weight = (lse - lse_max).exp()
        numerator += weight.unsqueeze(1) * o_norm.float()
        denominator += weight

    return numerator / denominator.unsqueeze(1).clamp(min=1e-30)
|
|
|
|
|
|
def run_fmha(q, k, v, kernel_obj, compiled, stream, swa_len, sink_bias=None):
    """Launch a compiled FMHA callable on one KV tile; return ``(O_norm, LSE)``.

    Zero-initialised output, LSE and row-sum buffers are allocated on
    ``'cuda'``, wrapped with ``to_cute`` and passed to ``compiled``. The raw
    output is then divided by the row sums (clamped to ``1e-30``).

    Args:
        q, k, v: 2-D tensors; each gets a trailing unit dim before wrapping.
        kernel_obj: accepted for the call signature; not used in this body.
        compiled: callable invoked as
            ``compiled(Q, K, V, C, stream, LSE, swa_len=..., row_sums=...)``,
            with ``sink_bias=...`` added when a sink bias is supplied.
        stream: stream handle forwarded to ``compiled``.
        swa_len: forwarded to ``compiled`` as ``swa_len``.
        sink_bias: optional tensor; forwarded (wrapped) when not None.

    Returns:
        ``(o_norm, lse)``: float32 ``(m, v.shape[1])`` normalized output and
        a cloned ``(m,)`` float32 LSE vector.
    """
    n_rows = q.shape[0]
    v_dim = v.shape[1]

    def _zeros(cols, dtype):
        # All kernel outputs share the (rows, cols, 1) layout.
        return torch.zeros(n_rows, cols, 1, dtype=dtype, device='cuda')

    out = _zeros(v_dim, torch.bfloat16)
    lse = _zeros(1, torch.float32)
    row_sums = _zeros(1, torch.float32)

    positional = [to_cute(x.unsqueeze(-1)) for x in (q, k, v)]
    positional += [to_cute(out), stream, to_cute(lse)]

    # Keyword order matches the original call sites:
    # swa_len, [sink_bias], row_sums.
    row_sums_cute = to_cute(row_sums)
    keyword = {"swa_len": swa_len}
    if sink_bias is not None:
        keyword["sink_bias"] = to_cute(sink_bias)
    keyword["row_sums"] = row_sums_cute

    compiled(*positional, **keyword)
    torch.cuda.synchronize()

    denom = row_sums[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
    o_norm = out[:, :, 0].float() / denom
    return o_norm, lse[:, 0, 0].clone()
|
|
|
|
|
|
def _run_segment_kernel(q, k_seg, v_seg, stream, *, head_dim, s_k, n_comp,
                        swa_len, attn_sink):
    """Compile and launch ``FmhaKernel`` on one KV segment.

    Returns ``(o_norm, lse)``: the kernel output divided by the kernel's row
    sums (clamped to ``1e-30``), and the per-row LSE vector the kernel wrote.
    """
    m = q.shape[0]
    kernel = FmhaKernel(head_dim=head_dim, s_k=s_k, normalize=False,
                        apply_swa_mask=True, is_causal=False,
                        n_comp=n_comp, apply_sink_bias=True)
    acc = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    row_sums = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    def _launch_args():
        # Fresh CuTe wrappers for each use (compile, then launch).
        args = (to_cute(q.unsqueeze(-1)), to_cute(k_seg.unsqueeze(-1)),
                to_cute(v_seg.unsqueeze(-1)), to_cute(acc), stream,
                to_cute(lse))
        kwargs = dict(swa_len=swa_len, sink_bias=to_cute(attn_sink),
                      row_sums=to_cute(row_sums))
        return args, kwargs

    args, kwargs = _launch_args()
    compiled = cute.compile(kernel, *args, **kwargs)
    args, kwargs = _launch_args()
    compiled(*args, **kwargs)
    torch.cuda.synchronize()

    o_norm = acc[:, :, 0].float() / row_sums[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
    return o_norm, lse[:, 0, 0]


def _flat_cosine(a, b):
    """Cosine similarity between two tensors viewed as flat float vectors."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0), b.flatten().unsqueeze(0).float()).item()


def test_d5c_multitile():
    """Two-segment FMHA with sink bias, merged in Python, vs. a full reference.

    The 256 combined KV positions (96 compressed + 160 SWA) are split into two
    128-wide segments. Each segment is run through ``FmhaKernel`` and compared
    with a per-segment torch reference (printed only); the two results are
    then merged with ``python_kv_merge`` and compared with
    ``reference_combined_attention``.

    Raises:
        AssertionError: if the merged cosine similarity is not >= 0.99
            (including when it is NaN).
    """
    print("=== D5c Multi-Tile: Sink Bias + Python KV Merge ===\n")

    hd = 64
    m = 128
    n_comp = 96
    n_swa = 160
    n_total = n_comp + n_swa  # 256
    swa_len = 100
    scale = 1.0 / math.sqrt(hd)
    seg_size = 128
    pass_threshold = 0.99
    torch.manual_seed(42)

    # Generation order is kept fixed so the seeded inputs stay the same.
    q = torch.randn(m, hd, dtype=torch.bfloat16, device='cuda')
    k_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    k_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
    v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
    attn_sink_val = 0.5
    attn_sink = torch.tensor([attn_sink_val], dtype=torch.float32, device='cuda')

    k_combined = torch.cat([k_comp, k_swa], dim=0)
    v_combined = torch.cat([v_comp, v_swa], dim=0)

    # Full reference
    ref = reference_combined_attention(q, k_comp, v_comp, k_swa, v_swa,
                                       attn_sink_val, scale, swa_len)

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Per-segment reference and kernel verification
    # Segment 0: positions 0-127 (n_comp=96, 32 SWA tokens, all valid, sink bias on SWA)
    k0 = k_combined[0:seg_size]
    v0 = v_combined[0:seg_size]
    scores0 = q.float() @ k0.float().T * scale
    scores0[:, n_comp:] += attn_sink_val
    # swa_len=100 but only 32 SWA positions in this segment -> nothing masked
    ref0 = (torch.softmax(scores0, dim=-1) @ v0.float()).to(torch.bfloat16)
    lse0_ref = torch.logsumexp(scores0, dim=-1)

    ok0, lse0_kern = _run_segment_kernel(
        q, k0, v0, stream, head_dim=hd, s_k=seg_size, n_comp=n_comp,
        swa_len=swa_len, attn_sink=attn_sink)
    cos0 = _flat_cosine(ok0, ref0)
    print(f'Seg0: cos {cos0:.6f} LSE_kern[0]={lse0_kern[0].item():.4f} LSE_ref[0]={lse0_ref[0].item():.4f}')

    # Segment 1: positions 128-255 (all SWA, n_comp=0, sink bias on all columns)
    k1 = k_combined[seg_size:n_total]
    v1 = v_combined[seg_size:n_total]
    scores1 = q.float() @ k1.float().T * scale
    scores1 += attn_sink_val  # all SWA -> sink bias on all
    # Valid SWA: absolute 96-195. In this segment (128-255): positions 0-67
    # valid, 68-127 masked.
    swa_len_seg1 = n_comp + swa_len - seg_size  # 96 + 100 - 128 = 68
    scores1[:, swa_len_seg1:] = float('-inf')
    ref1 = (torch.softmax(scores1, dim=-1) @ v1.float()).to(torch.bfloat16)
    lse1_ref = torch.logsumexp(scores1, dim=-1)

    ok1, lse1_kern = _run_segment_kernel(
        q, k1, v1, stream, head_dim=hd, s_k=seg_size, n_comp=0,
        swa_len=swa_len_seg1, attn_sink=attn_sink)
    cos1 = _flat_cosine(ok1, ref1)
    print(f'Seg1: cos {cos1:.6f} LSE_kern[0]={lse1_kern[0].item():.4f} LSE_ref[0]={lse1_ref[0].item():.4f}')

    # Merge
    o_merged = python_kv_merge([(ok0, lse0_kern), (ok1, lse1_kern)])
    cos = _flat_cosine(o_merged, ref)

    # `not passed` (rather than `cos < threshold`) also catches a NaN cosine.
    passed = cos >= pass_threshold
    status = "PASS" if passed else "FAIL"
    print(f'\nMerged: cos {cos:.6f} {status}')
    if not passed:
        print(f' kernel[0,:4]={o_merged[0,:4].tolist()}')
        print(f' ref[0,:4]={ref[0,:4].tolist()}')
        # Explicit raise (not `assert`) so the failure survives `python -O`.
        raise AssertionError(
            f"merged cosine similarity {cos:.6f} is not >= {pass_threshold}")
|
|
|
|
|
|
# Script entry point: run the multi-tile check when executed directly.
if __name__ == "__main__":
    test_d5c_multitile()
|