D5c: add sink bias (attn_sink) logit modification to FMHA kernel

- Add n_comp parameter: compressed KV length, sink bias applies to positions >= n_comp
- Add sink_bias parameter: per-head FP32 logit bias for SWA positions
- D3 mask updated: kv_pos >= n_comp + swa_len (backward compatible when n_comp=0)
- D4 causal mask updated: compare SWA-relative position (kv_pos - n_comp) with m_coord
- Mathematical insight: sink merge = single softmax over [S_comp, S_swa + attn_sink]
- Add test_d5c_fused.py with combined KV + sink bias test
This commit is contained in:
2026-05-26 14:59:52 +00:00
parent 60a6f2d296
commit 9d64434954
2 changed files with 326 additions and 14 deletions

View File

@@ -16,7 +16,11 @@ import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False):
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None):
# D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to
# positions >= n_comp. D3/D4 masks also only apply to SWA region.
# When n_comp is None or 0, no sink bias is applied (backward compatible).
self.n_comp = n_comp if n_comp is not None else 0
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
@@ -105,7 +109,7 @@ class FmhaKernel:
cute.size_in_bytes(self.q_dtype, v_s)) * cta
@cute.jit
def __call__(self, q, k, v, c, stream, lse=None, swa_len=None):
def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
@@ -138,12 +142,21 @@ class FmhaKernel:
swa_len = Int32(2147483647)
else:
swa_len = Int32(swa_len)
# D5c: sink_bias is a per-head FP32 logit bias applied to SWA positions.
# When None, pass 0.0 (no bias). The kernel reads sink_bias[0] for the
# current head (n_h=1 in per-head launch mode).
if const_expr(sink_bias is None):
# D5c: sink_bias not provided. Create a dummy tensor pointing to valid memory.
# Never actually read (const_expr(self.n_comp > 0) guards the read).
sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,)))
else:
sink_bias = ct.from_dlpack(sink_bias).mark_layout_dynamic(leading_dim=ct.get_leading_dim(sink_bias))
# Grid: (M_tiles, 1, batch) where M = n_h * T packed into M dimension
# For single-head (n_h=1): grid=(1,1,1) — backward compatible
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len):
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
@@ -411,29 +424,44 @@ class FmhaKernel:
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# D3/D4: In-kernel logit masking.
# After loading S from TMEM, mask invalid positions to -inf.
# D3/D4/D5c: In-kernel logit modification.
# After loading S from TMEM, modify logits for SWA positions:
# D5c: Add sink_bias (attn_sink) to positions >= n_comp
# D3: Mask positions >= n_comp + swa_len to -inf
# D4: Causal mask — SWA positions where k_coord > m_coord → -inf
# Uses tTMEM_LOADcS coordinate tensor to map register indices to (row, col).
# D3: SWA mask — positions >= swa_lens[batch_idx] → -inf
# D4: Causal mask — positions where k_coord > m_coord → -inf
# Both use the same coordinate mapping from tTMEM_LOADcS.
# For kt > 0, absolute KV pos = kt*128 + k_coord.
if const_expr(self.apply_swa_mask or self.is_causal):
if const_expr(self.apply_swa_mask or self.is_causal or self.n_comp > 0):
kt_offset = Int32(kt * 128) # KV position offset for this tile
# Iterate using same coordinate indexing as SMEM-P path
# D5c: Read sink bias once (same for all positions in this head).
# Define unconditionally for CuTeDSL scoping (used when n_comp > 0).
sink_val = Float32(0.0)
if const_expr(self.n_comp > 0):
sink_val = mSinkBias[Int32(0)]
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0] # query row position
k_coord = coord[1] # position within this KV tile
kv_pos = kt_offset + k_coord # absolute KV position
# D5c: Add sink bias to SWA positions (>= n_comp)
if const_expr(self.n_comp > 0):
if kv_pos >= Int32(self.n_comp):
tTMEM_LOADrS[(j0, 0), j1, 0, 0] = tTMEM_LOADrS[(j0, 0), j1, 0, 0] + sink_val
# D3: SWA length mask
should_mask = Boolean(0)
if const_expr(self.apply_swa_mask):
if kv_pos >= swa_len:
# SWA length applies relative to the SWA region start (n_comp)
# kv_pos >= n_comp + swa_len means the SWA position >= swa_len
if kv_pos >= Int32(self.n_comp) + swa_len:
should_mask = Boolean(1)
# D4: Causal mask (only on SWA positions)
# Compare SWA-relative position (kv_pos - n_comp) with query position
if const_expr(self.is_causal):
if k_coord > m_coord:
should_mask = Boolean(1)
if kv_pos >= Int32(self.n_comp):
swa_pos = kv_pos - Int32(self.n_comp)
if swa_pos > m_coord:
should_mask = Boolean(1)
if should_mask:
tTMEM_LOADrS[(j0, 0), j1, 0, 0] = -Float32.inf

View File

@@ -0,0 +1,284 @@
"""
FMHA D5c: Fused sparse + SWA attention via combined KV + sink bias.
Mathematical insight: the sink merge is equivalent to a single attention
pass over the concatenated KV with a logit bias (attn_sink) applied to
the SWA portion. No two passes needed, no merge epilogue.
S = [S_comp, S_swa + attn_sink]
O = softmax(S) @ [V_comp; V_swa]
This is identical to:
O = (exp(lse_sparse)*O_sparse + exp(sink)*exp(lse_swa)*O_swa)
/ (exp(lse_sparse) + exp(sink)*exp(lse_swa))
The kernel changes are minimal:
1. K = [compressed_K; swa_K], V = [compressed_V; swa_V]
2. n_comp = length of compressed KV (sink bias applies to positions >= n_comp)
3. attn_sink = per-head logit bias for SWA positions
4. D3/D4 masking applies to SWA region (positions >= n_comp)
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d5c_fused.py
"""
import torch
import math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def reference_combined_attention(q, k_comp, v_comp, k_swa, v_swa,
attn_sink, scale, swa_len, is_causal=False):
"""FP32 reference: single softmax over combined KV with sink bias on SWA."""
m, hd = q.shape
n_comp = k_comp.shape[0]
n_swa = k_swa.shape[0]
# Concatenate KV
k_combined = torch.cat([k_comp, k_swa], dim=0) # (n_comp + n_swa, hd)
v_combined = torch.cat([v_comp, v_swa], dim=0)
# Compute combined logits
scores = torch.matmul(q.float(), k_combined.float().T) * scale # (m, n_comp + n_swa)
# Add sink bias to SWA positions
scores[:, n_comp:] += attn_sink
# D3: SWA length mask (only SWA region)
if swa_len < n_swa:
scores[:, n_comp + swa_len:] = float('-inf')
# D4: causal mask (only SWA region)
if is_causal:
# Within SWA region: mask k_coord > m_coord
for i in range(m):
for j in range(n_swa):
if j > i: # k_coord > m_coord
scores[i, n_comp + j] = float('-inf')
# Softmax + PV
max_s = scores.max(dim=-1, keepdim=True).values
exp_s = (scores - max_s).exp()
sum_s = exp_s.sum(dim=-1, keepdim=True).clamp(min=1e-30)
p = exp_s / sum_s
o = torch.matmul(p, v_combined.float())
return o.to(torch.bfloat16)
def reference_sink_merge(q, k_comp, v_comp, k_swa, v_swa,
attn_sink, scale, swa_len, is_causal=False):
"""FP32 reference: separate attention + sink merge (original D5b formula)."""
m, hd = q.shape
n_comp = k_comp.shape[0]
n_swa = k_swa.shape[0]
# Compressed KV attention (no mask)
attn_comp = torch.matmul(q.float(), k_comp.float().T) * scale
o_norm_comp = torch.softmax(attn_comp, dim=-1) @ v_comp.float()
lse_comp = torch.logsumexp(attn_comp, dim=-1, keepdim=True) # (m, 1)
# SWA KV attention (with swa_len mask)
attn_swa = torch.matmul(q.float(), k_swa.float().T) * scale
if swa_len < n_swa:
attn_swa[:, swa_len:] = float('-inf')
if is_causal:
for i in range(m):
for j in range(n_swa):
if j > i:
attn_swa[i, j] = float('-inf')
o_norm_swa = torch.softmax(attn_swa, dim=-1) @ v_swa.float()
lse_swa = torch.logsumexp(attn_swa, dim=-1, keepdim=True)
# Sink merge (normalized formula)
exp_sink = attn_sink.exp()
numerator = lse_comp.exp() * o_norm_comp + exp_sink * lse_swa.exp() * o_norm_swa
denominator = (lse_comp.exp() + exp_sink * lse_swa.exp()).clamp(min=1e-30)
o = numerator / denominator
return o.to(torch.bfloat16)
def test_d5c_combined():
print("=== Stage D5c: Fused Sparse+SWA via Combined KV + Sink Bias ===\n")
hd = 64
m = 128 # query rows
n_comp = 128 # compressed KV length
n_swa = 128 # SWA window length
n_total = n_comp + n_swa # combined KV length = 256
swa_len = 64 # actual SWA fill
scale = 1.0 / math.sqrt(hd)
torch.manual_seed(42)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda')
v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda')
v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
# Sink weight (in log domain, per-head). n_h=1 so it's a scalar.
attn_sink = torch.tensor([0.5], dtype=torch.float32, device='cuda')
# === FP32 References ===
qf = q[:, :, 0]
# Reference 1: Combined softmax with sink bias (our kernel's approach)
ref_combined = reference_combined_attention(
qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
attn_sink[0].item(), scale, swa_len
)
# Reference 2: Separate attention + sink merge (original D5b formula)
ref_merge = reference_sink_merge(
qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
attn_sink[0].item(), scale, swa_len
)
# Verify the two references agree
cos_ref = torch.nn.functional.cosine_similarity(
ref_combined.flatten().unsqueeze(0).float(),
ref_merge.flatten().unsqueeze(0).float()
).item()
print(f"Reference: combined softmax vs sink merge cos = {cos_ref:.6f}")
assert cos_ref > 0.999, f"References don't match: cos={cos_ref}"
# === Kernel ===
# Concatenate KV for the kernel
k_combined = torch.cat([k_comp, k_swa], dim=0) # (n_total, hd, 1)
v_combined = torch.cat([v_comp, v_swa], dim=0) # (n_total, hd)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FmhaKernel(
head_dim=hd,
s_k=n_total, # combined KV length
normalize=False, # D5a: emit un-normalized O + LSE
apply_swa_mask=True, # D3: mask SWA positions
is_causal=False, # D4: no causal mask for this test
n_comp=n_comp, # D5c: compressed KV length (sink bias starts here)
)
# Allocate output
c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
# Prepare CuTe tensors
def to_cute(t):
return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
mQ = to_cute(q)
mK = to_cute(k_combined)
mV = to_cute(v_combined.unsqueeze(-1).contiguous())
mC = to_cute(c_out)
mLSE = to_cute(lse_out)
# Compile
print('Compiling D5c kernel (combined KV + sink bias)...', flush=True)
compiled = cute.compile(
kernel, mQ, mK, mV, mC, stream, mLSE,
swa_len=swa_len, sink_bias=attn_sink,
)
# Run
print('Running D5c kernel...', flush=True)
compiled(
mQ, mK, mV, mC, stream, mLSE,
swa_len=swa_len, sink_bias=attn_sink,
)
torch.cuda.synchronize()
# Check results
o_kernel = c_out[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
o_kernel.flatten().unsqueeze(0),
ref_combined.flatten().unsqueeze(0).float()
).item()
max_abs = (o_kernel - ref_combined.float()).abs().max().item()
status = "PASS" if cos >= 0.95 else "FAIL"
print(f'\nD5c result: cos {cos:.6f} max_abs {max_abs:.4f} {status}')
if cos < 0.95:
print(f' kernel[0,:4]={o_kernel[0,:4].tolist()}')
print(f' ref[0,:4]={ref_combined[0,:4].tolist()}')
def test_d5c_with_causal():
"""D5c with causal mask on SWA branch."""
print("\n=== Stage D5c: Fused Sparse+SWA with Causal Mask ===\n")
hd = 64
m = 128
n_comp = 64
n_swa = 128
n_total = n_comp + n_swa
swa_len = 96 # partially filled SWA
scale = 1.0 / math.sqrt(hd)
torch.manual_seed(123)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda')
v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda')
v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
attn_sink = torch.tensor([0.3], dtype=torch.float32, device='cuda')
qf = q[:, :, 0]
ref = reference_combined_attention(
qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
attn_sink[0].item(), scale, swa_len, is_causal=True
)
# Concatenate KV
k_combined = torch.cat([k_comp, k_swa], dim=0)
v_combined = torch.cat([v_comp, v_swa], dim=0)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FmhaKernel(
head_dim=hd,
s_k=n_total,
normalize=False,
apply_swa_mask=True,
is_causal=True,
n_comp=n_comp,
)
c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
def to_cute(t):
return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
mQ = to_cute(q)
mK = to_cute(k_combined)
mV = to_cute(v_combined.unsqueeze(-1).contiguous())
mC = to_cute(c_out)
mLSE = to_cute(lse_out)
print('Compiling D5c kernel (causal + sink bias)...', flush=True)
compiled = cute.compile(
kernel, mQ, mK, mV, mC, stream, mLSE,
swa_len=swa_len, sink_bias=attn_sink,
)
compiled(
mQ, mK, mV, mC, stream, mLSE,
swa_len=swa_len, sink_bias=attn_sink,
)
torch.cuda.synchronize()
o_kernel = c_out[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
o_kernel.flatten().unsqueeze(0),
ref.flatten().unsqueeze(0).float()
).item()
max_abs = (o_kernel - ref.float()).abs().max().item()
status = "PASS" if cos >= 0.95 else "FAIL"
print(f'D5c causal result: cos {cos:.6f} max_abs {max_abs:.4f} {status}')
if __name__ == '__main__':
test_d5c_combined()
test_d5c_with_causal()