For all-SWA segments (n_comp=0), the sink bias must still be applied to every position. The apply_sink_bias flag controls whether the sink-bias code path is compiled, independently of the n_comp offset.
305 lines
11 KiB
Python
305 lines
11 KiB
Python
"""
|
|
FMHA D5c: Fused sparse + SWA attention via combined KV + sink bias.
|
|
|
|
Mathematical insight: the sink merge is equivalent to a single attention
|
|
pass over the concatenated KV with a logit bias (attn_sink) applied to
|
|
the SWA portion. No two passes needed, no merge epilogue.
|
|
|
|
S = [S_comp, S_swa + attn_sink]
|
|
O = softmax(S) @ [V_comp; V_swa]
|
|
|
|
This is identical to:
|
|
O = (exp(lse_sparse)*O_sparse + exp(sink)*exp(lse_swa)*O_swa)
|
|
/ (exp(lse_sparse) + exp(sink)*exp(lse_swa))
|
|
|
|
The kernel changes are minimal:
|
|
1. K = [compressed_K; swa_K], V = [compressed_V; swa_V]
|
|
2. n_comp = length of compressed KV (sink bias applies to positions >= n_comp)
|
|
3. attn_sink = per-head logit bias for SWA positions
|
|
4. D3/D4 masking applies to SWA region (positions >= n_comp)
|
|
|
|
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d5c_fused.py
|
|
"""
|
|
import torch
|
|
import math
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as ct
|
|
import cuda.bindings.driver as cuda
|
|
from dsv4.kernels.attention.fmha import FmhaKernel
|
|
|
|
|
|
def reference_combined_attention(q, k_comp, v_comp, k_swa, v_swa,
                                 attn_sink, scale, swa_len, is_causal=False):
    """FP32 reference: single softmax over combined KV with sink bias on SWA.

    Concatenates ``[k_comp; k_swa]`` and ``[v_comp; v_swa]`` along dim 0,
    computes ``q @ K^T * scale`` in float32, adds ``attn_sink`` to the SWA
    columns (column index >= n_comp), applies the D3 (swa_len) and D4
    (causal) masks to the SWA columns only, and returns ``softmax(S) @ V``.

    Args:
        q: Query matrix, indexed as (m, hd).
        k_comp, v_comp: Compressed keys / values, (n_comp, hd). n_comp may be 0.
        k_swa, v_swa: Sliding-window keys / values, (n_swa, hd).
        attn_sink: Logit bias added to every SWA column (scalar or anything
            that broadcasts against the SWA score columns).
        scale: Multiplier applied to the raw q @ k^T logits.
        swa_len: Number of valid SWA entries; SWA columns at index >= swa_len
            are masked out (D3).
        is_causal: If True, SWA column j is masked for query row i when j > i (D4).

    Returns:
        (m, hd) tensor in bfloat16. Rows whose every logit is masked produce
        zeros instead of NaN.
    """
    m = q.shape[0]
    n_comp = k_comp.shape[0]
    n_swa = k_swa.shape[0]

    # Concatenate KV
    k_combined = torch.cat([k_comp, k_swa], dim=0)  # (n_comp + n_swa, hd)
    v_combined = torch.cat([v_comp, v_swa], dim=0)

    # Compute combined logits
    scores = torch.matmul(q.float(), k_combined.float().T) * scale  # (m, n_comp + n_swa)

    # Basic slicing returns a view, so the in-place edits below write through to `scores`.
    swa_scores = scores[:, n_comp:]

    # Add sink bias to SWA positions
    swa_scores += attn_sink

    # D3: SWA length mask (only SWA region)
    if swa_len < n_swa:
        swa_scores[:, swa_len:] = float('-inf')

    # D4: causal mask (only SWA region): mask SWA column j for query row i when j > i
    if is_causal:
        rows = torch.arange(m, device=scores.device)
        cols = torch.arange(n_swa, device=scores.device)
        causal_mask = cols.unsqueeze(0) > rows.unsqueeze(1)  # (m, n_swa)
        swa_scores.masked_fill_(causal_mask, float('-inf'))

    # Softmax + PV. A fully masked row has max -inf; subtracting -inf would give
    # NaN, so use 0 as its shift, which makes every exp() 0 and the output row 0.
    max_s = scores.max(dim=-1, keepdim=True).values
    max_s = max_s.masked_fill(torch.isneginf(max_s), 0.0)
    exp_s = (scores - max_s).exp()
    sum_s = exp_s.sum(dim=-1, keepdim=True).clamp(min=1e-30)
    p = exp_s / sum_s
    o = torch.matmul(p, v_combined.float())
    return o.to(torch.bfloat16)
|
|
|
|
|
|
def reference_sink_merge(q, k_comp, v_comp, k_swa, v_swa,
                         attn_sink, scale, swa_len, is_causal=False):
    """FP32 reference: separate attention + sink merge (original D5b formula).

    Runs two independent softmax attentions, one over the compressed KV
    (unmasked) and one over the SWA KV (D3 swa_len mask, optional D4 causal
    mask). It then merges the normalized outputs with their log-sum-exps:

        O = (e^{lse_c} O_c + e^{sink} e^{lse_s} O_s) / (e^{lse_c} + e^{sink} e^{lse_s})

    The merge is evaluated in log space (both weights shifted by their max),
    so large logits do not overflow ``exp(lse)``. A branch whose every logit
    is masked contributes zero weight instead of NaN.

    Args:
        q: Query matrix, indexed as (m, hd).
        k_comp, v_comp: Compressed keys / values, (n_comp, hd).
        k_swa, v_swa: Sliding-window keys / values, (n_swa, hd).
        attn_sink: Scalar sink logit; must be convertible with ``float()``.
        scale: Multiplier applied to the raw q @ k^T logits.
        swa_len: Number of valid SWA entries; SWA columns at index >= swa_len are masked.
        is_causal: If True, SWA column j is masked for query row i when j > i.

    Returns:
        (m, hd) tensor in bfloat16.
    """
    m = q.shape[0]
    n_swa = k_swa.shape[0]
    sink = float(attn_sink)  # same scalar contract as the former math.exp(attn_sink)

    # Compressed KV attention (no mask)
    attn_comp = torch.matmul(q.float(), k_comp.float().T) * scale
    o_norm_comp = torch.softmax(attn_comp, dim=-1) @ v_comp.float()
    lse_comp = torch.logsumexp(attn_comp, dim=-1, keepdim=True)  # (m, 1)

    # SWA KV attention (with swa_len mask and optional causal mask)
    attn_swa = torch.matmul(q.float(), k_swa.float().T) * scale
    if swa_len < n_swa:
        attn_swa[:, swa_len:] = float('-inf')
    if is_causal:
        rows = torch.arange(m, device=attn_swa.device)
        cols = torch.arange(n_swa, device=attn_swa.device)
        causal_mask = cols.unsqueeze(0) > rows.unsqueeze(1)  # (m, n_swa)
        attn_swa.masked_fill_(causal_mask, float('-inf'))
    o_norm_swa = torch.softmax(attn_swa, dim=-1) @ v_swa.float()
    lse_swa = torch.logsumexp(attn_swa, dim=-1, keepdim=True)

    # softmax over an all -inf row is NaN. Such a branch has lse = -inf (zero
    # weight), so zero its output; otherwise 0 * NaN would poison the merge.
    o_norm_comp = torch.where(torch.isneginf(lse_comp), torch.zeros_like(o_norm_comp), o_norm_comp)
    o_norm_swa = torch.where(torch.isneginf(lse_swa), torch.zeros_like(o_norm_swa), o_norm_swa)

    # Sink merge (normalized formula) in log space: shift both weights by their
    # max so exp() never overflows. A -inf max (both branches empty) uses shift 0.
    lse_swa_biased = lse_swa + sink
    lse_max = torch.maximum(lse_comp, lse_swa_biased)
    lse_max = lse_max.masked_fill(torch.isneginf(lse_max), 0.0)
    w_comp = (lse_comp - lse_max).exp()
    w_swa = (lse_swa_biased - lse_max).exp()

    numerator = w_comp * o_norm_comp + w_swa * o_norm_swa
    denominator = (w_comp + w_swa).clamp(min=1e-30)
    o = numerator / denominator
    return o.to(torch.bfloat16)
|
|
|
|
|
|
def test_d5c_combined():
    """D5c: fused sparse + SWA attention via combined KV + sink bias (no causal mask).

    Builds a single-head problem whose combined KV length (n_comp + n_swa = 128)
    is intended to fit one 128-wide KV tile. It first checks that the
    combined-softmax reference and the two-pass sink-merge reference agree. It
    then compiles and runs ``FmhaKernel`` with ``normalize=False`` and
    ``apply_sink_bias=True``, divides the un-normalized output by the emitted
    per-row sums, and compares the result to the combined reference.

    Raises:
        AssertionError: if the two references disagree (cos <= 0.999) or the
            kernel output's cosine similarity to the reference is below 0.99
            (including a NaN similarity).
    """
    print("=== Stage D5c: Fused Sparse+SWA via Combined KV + Sink Bias ===\n")

    hd = 64
    m = 128  # query rows
    n_comp = 64  # compressed KV length (fit in single 128-wide KV tile)
    n_swa = 64  # SWA window length
    n_total = n_comp + n_swa  # combined KV length = 128 (1 KV tile)
    swa_len = 40  # actual SWA fill
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(42)

    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda')
    v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda')
    v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')

    # Sink weight (in log domain, per-head). n_h=1 so it's a scalar.
    attn_sink_val = 0.5
    attn_sink = torch.tensor([attn_sink_val], dtype=torch.float32, device='cuda')

    # === FP32 References ===
    qf = q[:, :, 0]

    # Reference 1: Combined softmax with sink bias (our kernel's approach)
    ref_combined = reference_combined_attention(
        qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
        attn_sink_val, scale, swa_len
    )

    # Reference 2: Separate attention + sink merge (original D5b formula)
    ref_merge = reference_sink_merge(
        qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
        attn_sink_val, scale, swa_len
    )

    # Verify the two references agree
    cos_ref = torch.nn.functional.cosine_similarity(
        ref_combined.flatten().unsqueeze(0).float(),
        ref_merge.flatten().unsqueeze(0).float()
    ).item()
    print(f"Reference: combined softmax vs sink merge cos = {cos_ref:.6f}")
    assert cos_ref > 0.999, f"References don't match: cos={cos_ref}"

    # === Kernel ===
    # Concatenate KV for the kernel
    k_combined = torch.cat([k_comp, k_swa], dim=0)  # (n_total, hd, 1)
    v_combined = torch.cat([v_comp, v_swa], dim=0)  # (n_total, hd)

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    kernel = FmhaKernel(
        head_dim=hd,
        s_k=n_total,  # combined KV length
        normalize=False,  # D5a: emit un-normalized O + LSE
        apply_swa_mask=True,  # D3: mask SWA positions
        is_causal=False,  # D4: no causal mask for this test
        n_comp=n_comp,  # D5c: compressed KV length (sink bias starts here)
        apply_sink_bias=True,  # D5c: enable sink bias logit modification
    )

    # Allocate output
    c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    row_sum_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    # Prepare CuTe tensors
    def to_cute(t):
        return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    mQ = to_cute(q)
    mK = to_cute(k_combined)
    mV = to_cute(v_combined.unsqueeze(-1).contiguous())
    mC = to_cute(c_out)
    mLSE = to_cute(lse_out)
    mRowSums = to_cute(row_sum_out)

    # Compile
    print('Compiling D5c kernel (combined KV + sink bias)...', flush=True)
    mSinkBias = to_cute(attn_sink)
    compiled = cute.compile(
        kernel, mQ, mK, mV, mC, stream, mLSE,
        swa_len=swa_len, sink_bias=mSinkBias, row_sums=mRowSums,
    )

    # Run
    print('Running D5c kernel...', flush=True)
    compiled(
        mQ, mK, mV, mC, stream, mLSE,
        swa_len=swa_len, sink_bias=mSinkBias, row_sums=mRowSums,
    )
    torch.cuda.synchronize()

    # Check results
    # Kernel outputs UN-NORMALIZED O (normalize=False). Normalize using per-row row_sum.
    # O_norm[i] = O_unnorm[i] / row_sum[i]
    o_kernel_unnorm = c_out[:, :, 0].float()  # (m, hd)
    row_sums = row_sum_out[:, 0, 0].float()  # (m,)

    # Normalize each row by its row_sum
    o_kernel = o_kernel_unnorm / row_sums.unsqueeze(1).clamp(min=1e-30)

    cos = torch.nn.functional.cosine_similarity(
        o_kernel.flatten().unsqueeze(0),
        ref_combined.flatten().unsqueeze(0).float()
    ).item()
    max_abs = (o_kernel - ref_combined.float()).abs().max().item()

    status = "PASS" if cos >= 0.99 else "FAIL"
    print(f'\nD5c result: cos {cos:.6f} max_abs {max_abs:.4f} {status}')

    # Branch on `status`, not `cos < 0.99`: a NaN cos fails both comparisons,
    # so `cos < 0.99` would silently skip the diagnostics on a NaN result.
    if status == "FAIL":
        print(f'  kernel[0,:4]={o_kernel[0,:4].tolist()}')
        print(f'  ref[0,:4]={ref_combined[0,:4].tolist()}')
        print(f'  row_sum range: {row_sums.min().item():.4f} to {row_sums.max().item():.4f}')
        print(f'  LSE range: {lse_out[:,0,0].min().item():.4f} to {lse_out[:,0,0].max().item():.4f}')

    # Without this the test reports FAIL but still passes under pytest.
    assert status == "PASS", f"D5c kernel mismatch: cos={cos:.6f} max_abs={max_abs:.4f}"
|
|
|
|
|
|
def test_d5c_with_causal():
    """D5c with causal mask on SWA branch.

    Same setup as the non-causal D5c check but with ``is_causal=True``. The
    reference masks SWA column j for query row i when j > i. The kernel's
    row-normalized output is compared to ``reference_combined_attention``
    with the same causal flag.

    Raises:
        AssertionError: if the kernel output's cosine similarity to the
            reference is below 0.99 (including a NaN similarity).
    """
    print("\n=== Stage D5c: Fused Sparse+SWA with Causal Mask ===\n")

    hd = 64
    m = 128
    n_comp = 64
    n_swa = 64
    n_total = n_comp + n_swa  # 128, single KV tile
    swa_len = 48  # partially filled SWA
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(123)

    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda')
    v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
    k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda')
    v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')

    attn_sink_val = 0.3
    attn_sink = torch.tensor([attn_sink_val], dtype=torch.float32, device='cuda')
    qf = q[:, :, 0]

    ref = reference_combined_attention(
        qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
        attn_sink_val, scale, swa_len, is_causal=True
    )

    # Concatenate KV
    k_combined = torch.cat([k_comp, k_swa], dim=0)
    v_combined = torch.cat([v_comp, v_swa], dim=0)

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    kernel = FmhaKernel(
        head_dim=hd,
        s_k=n_total,
        normalize=False,
        apply_swa_mask=True,
        is_causal=True,
        n_comp=n_comp,
        apply_sink_bias=True,
    )

    c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    row_sum_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    def to_cute(t):
        return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    mQ = to_cute(q)
    mK = to_cute(k_combined)
    mV = to_cute(v_combined.unsqueeze(-1).contiguous())
    mC = to_cute(c_out)
    mLSE = to_cute(lse_out)
    mRowSums = to_cute(row_sum_out)
    mSinkBias = to_cute(attn_sink)
    compiled = cute.compile(
        kernel, mQ, mK, mV, mC, stream, mLSE,
        swa_len=swa_len, sink_bias=mSinkBias, row_sums=mRowSums,
    )
    compiled(
        mQ, mK, mV, mC, stream, mLSE,
        swa_len=swa_len, sink_bias=mSinkBias, row_sums=mRowSums,
    )
    torch.cuda.synchronize()

    # Kernel emits un-normalized O (normalize=False); divide each row by its row sum.
    o_kernel_unnorm = c_out[:, :, 0].float()
    row_sums = row_sum_out[:, 0, 0].float()
    o_kernel = o_kernel_unnorm / row_sums.unsqueeze(1).clamp(min=1e-30)

    cos = torch.nn.functional.cosine_similarity(
        o_kernel.flatten().unsqueeze(0),
        ref.flatten().unsqueeze(0).float()
    ).item()
    max_abs = (o_kernel - ref.float()).abs().max().item()

    status = "PASS" if cos >= 0.99 else "FAIL"
    print(f'D5c causal result: cos {cos:.6f} max_abs {max_abs:.4f} {status}')

    # Branch on `status` so a NaN cos (which fails every comparison) still prints diagnostics.
    if status == "FAIL":
        print(f'  kernel[0,:4]={o_kernel[0,:4].tolist()}')
        print(f'  ref[0,:4]={ref[0,:4].tolist()}')
        print(f'  row_sum range: {row_sums.min().item():.4f} to {row_sums.max().item():.4f}')

    # Without this the test reports FAIL but still passes under pytest.
    assert status == "PASS", f"D5c causal kernel mismatch: cos={cos:.6f} max_abs={max_abs:.4f}"
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the non-causal check first, then the causal one.
    for _d5c_check in (test_d5c_combined, test_d5c_with_causal):
        _d5c_check()
|