D5c: add sink bias (attn_sink) logit modification to FMHA kernel
- Add n_comp parameter: compressed KV length, sink bias applies to positions >= n_comp - Add sink_bias parameter: per-head FP32 logit bias for SWA positions - D3 mask updated: kv_pos >= n_comp + swa_len (backward compatible when n_comp=0) - D4 causal mask updated: compare SWA-relative position (kv_pos - n_comp) with m_coord - Mathematical insight: sink merge = single softmax over [S_comp, S_swa + attn_sink] - Add test_d5c_fused.py with combined KV + sink bias test
This commit is contained in:
@@ -16,7 +16,11 @@ import math
|
||||
|
||||
|
||||
class FmhaKernel:
|
||||
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False):
|
||||
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None):
|
||||
# D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to
|
||||
# positions >= n_comp. D3/D4 masks also only apply to SWA region.
|
||||
# When n_comp is None or 0, no sink bias is applied (backward compatible).
|
||||
self.n_comp = n_comp if n_comp is not None else 0
|
||||
self.head_dim = head_dim
|
||||
self.s_k = s_k
|
||||
self.n_kv_tiles = s_k // 128
|
||||
@@ -105,7 +109,7 @@ class FmhaKernel:
|
||||
cute.size_in_bytes(self.q_dtype, v_s)) * cta
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q, k, v, c, stream, lse=None, swa_len=None):
|
||||
def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None):
|
||||
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
@@ -138,12 +142,21 @@ class FmhaKernel:
|
||||
swa_len = Int32(2147483647)
|
||||
else:
|
||||
swa_len = Int32(swa_len)
|
||||
# D5c: sink_bias is a per-head FP32 logit bias applied to SWA positions.
|
||||
# When None, pass 0.0 (no bias). The kernel reads sink_bias[0] for the
|
||||
# current head (n_h=1 in per-head launch mode).
|
||||
if const_expr(sink_bias is None):
|
||||
# D5c: sink_bias not provided. Create a dummy tensor pointing to valid memory.
|
||||
# Never actually read (const_expr(self.n_comp > 0) guards the read).
|
||||
sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,)))
|
||||
else:
|
||||
sink_bias = ct.from_dlpack(sink_bias).mark_layout_dynamic(leading_dim=ct.get_leading_dim(sink_bias))
|
||||
# Grid: (M_tiles, 1, batch) where M = n_h * T packed into M dimension
|
||||
# For single-head (n_h=1): grid=(1,1,1) — backward compatible
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len):
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx,_,_ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
@@ -411,29 +424,44 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# D3/D4: In-kernel logit masking.
|
||||
# After loading S from TMEM, mask invalid positions to -inf.
|
||||
# D3/D4/D5c: In-kernel logit modification.
|
||||
# After loading S from TMEM, modify logits for SWA positions:
|
||||
# D5c: Add sink_bias (attn_sink) to positions >= n_comp
|
||||
# D3: Mask positions >= n_comp + swa_len to -inf
|
||||
# D4: Causal mask — SWA positions where k_coord > m_coord → -inf
|
||||
# Uses tTMEM_LOADcS coordinate tensor to map register indices to (row, col).
|
||||
# D3: SWA mask — positions >= swa_lens[batch_idx] → -inf
|
||||
# D4: Causal mask — positions where k_coord > m_coord → -inf
|
||||
# Both use the same coordinate mapping from tTMEM_LOADcS.
|
||||
# For kt > 0, absolute KV pos = kt*128 + k_coord.
|
||||
if const_expr(self.apply_swa_mask or self.is_causal):
|
||||
if const_expr(self.apply_swa_mask or self.is_causal or self.n_comp > 0):
|
||||
kt_offset = Int32(kt * 128) # KV position offset for this tile
|
||||
# Iterate using same coordinate indexing as SMEM-P path
|
||||
# D5c: Read sink bias once (same for all positions in this head).
|
||||
# Define unconditionally for CuTeDSL scoping (used when n_comp > 0).
|
||||
sink_val = Float32(0.0)
|
||||
if const_expr(self.n_comp > 0):
|
||||
sink_val = mSinkBias[Int32(0)]
|
||||
for j0 in range(32):
|
||||
for j1 in range(4):
|
||||
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
|
||||
m_coord = coord[0] # query row position
|
||||
k_coord = coord[1] # position within this KV tile
|
||||
kv_pos = kt_offset + k_coord # absolute KV position
|
||||
# D5c: Add sink bias to SWA positions (>= n_comp)
|
||||
if const_expr(self.n_comp > 0):
|
||||
if kv_pos >= Int32(self.n_comp):
|
||||
tTMEM_LOADrS[(j0, 0), j1, 0, 0] = tTMEM_LOADrS[(j0, 0), j1, 0, 0] + sink_val
|
||||
# D3: SWA length mask
|
||||
should_mask = Boolean(0)
|
||||
if const_expr(self.apply_swa_mask):
|
||||
if kv_pos >= swa_len:
|
||||
# SWA length applies relative to the SWA region start (n_comp)
|
||||
# kv_pos >= n_comp + swa_len means the SWA position >= swa_len
|
||||
if kv_pos >= Int32(self.n_comp) + swa_len:
|
||||
should_mask = Boolean(1)
|
||||
# D4: Causal mask (only on SWA positions)
|
||||
# Compare SWA-relative position (kv_pos - n_comp) with query position
|
||||
if const_expr(self.is_causal):
|
||||
if k_coord > m_coord:
|
||||
should_mask = Boolean(1)
|
||||
if kv_pos >= Int32(self.n_comp):
|
||||
swa_pos = kv_pos - Int32(self.n_comp)
|
||||
if swa_pos > m_coord:
|
||||
should_mask = Boolean(1)
|
||||
if should_mask:
|
||||
tTMEM_LOADrS[(j0, 0), j1, 0, 0] = -Float32.inf
|
||||
|
||||
|
||||
284
tests/unit/test_d5c_fused.py
Normal file
284
tests/unit/test_d5c_fused.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
FMHA D5c: Fused sparse + SWA attention via combined KV + sink bias.
|
||||
|
||||
Mathematical insight: the sink merge is equivalent to a single attention
|
||||
pass over the concatenated KV with a logit bias (attn_sink) applied to
|
||||
the SWA portion. No two passes needed, no merge epilogue.
|
||||
|
||||
S = [S_comp, S_swa + attn_sink]
|
||||
O = softmax(S) @ [V_comp; V_swa]
|
||||
|
||||
This is identical to:
|
||||
O = (exp(lse_sparse)*O_sparse + exp(sink)*exp(lse_swa)*O_swa)
|
||||
/ (exp(lse_sparse) + exp(sink)*exp(lse_swa))
|
||||
|
||||
The kernel changes are minimal:
|
||||
1. K = [compressed_K; swa_K], V = [compressed_V; swa_V]
|
||||
2. n_comp = length of compressed KV (sink bias applies to positions >= n_comp)
|
||||
3. attn_sink = per-head logit bias for SWA positions
|
||||
4. D3/D4 masking applies to SWA region (positions >= n_comp)
|
||||
|
||||
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d5c_fused.py
|
||||
"""
|
||||
import torch
|
||||
import math
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
|
||||
def reference_combined_attention(q, k_comp, v_comp, k_swa, v_swa,
|
||||
attn_sink, scale, swa_len, is_causal=False):
|
||||
"""FP32 reference: single softmax over combined KV with sink bias on SWA."""
|
||||
m, hd = q.shape
|
||||
n_comp = k_comp.shape[0]
|
||||
n_swa = k_swa.shape[0]
|
||||
|
||||
# Concatenate KV
|
||||
k_combined = torch.cat([k_comp, k_swa], dim=0) # (n_comp + n_swa, hd)
|
||||
v_combined = torch.cat([v_comp, v_swa], dim=0)
|
||||
|
||||
# Compute combined logits
|
||||
scores = torch.matmul(q.float(), k_combined.float().T) * scale # (m, n_comp + n_swa)
|
||||
|
||||
# Add sink bias to SWA positions
|
||||
scores[:, n_comp:] += attn_sink
|
||||
|
||||
# D3: SWA length mask (only SWA region)
|
||||
if swa_len < n_swa:
|
||||
scores[:, n_comp + swa_len:] = float('-inf')
|
||||
|
||||
# D4: causal mask (only SWA region)
|
||||
if is_causal:
|
||||
# Within SWA region: mask k_coord > m_coord
|
||||
for i in range(m):
|
||||
for j in range(n_swa):
|
||||
if j > i: # k_coord > m_coord
|
||||
scores[i, n_comp + j] = float('-inf')
|
||||
|
||||
# Softmax + PV
|
||||
max_s = scores.max(dim=-1, keepdim=True).values
|
||||
exp_s = (scores - max_s).exp()
|
||||
sum_s = exp_s.sum(dim=-1, keepdim=True).clamp(min=1e-30)
|
||||
p = exp_s / sum_s
|
||||
o = torch.matmul(p, v_combined.float())
|
||||
return o.to(torch.bfloat16)
|
||||
|
||||
|
||||
def reference_sink_merge(q, k_comp, v_comp, k_swa, v_swa,
|
||||
attn_sink, scale, swa_len, is_causal=False):
|
||||
"""FP32 reference: separate attention + sink merge (original D5b formula)."""
|
||||
m, hd = q.shape
|
||||
n_comp = k_comp.shape[0]
|
||||
n_swa = k_swa.shape[0]
|
||||
|
||||
# Compressed KV attention (no mask)
|
||||
attn_comp = torch.matmul(q.float(), k_comp.float().T) * scale
|
||||
o_norm_comp = torch.softmax(attn_comp, dim=-1) @ v_comp.float()
|
||||
lse_comp = torch.logsumexp(attn_comp, dim=-1, keepdim=True) # (m, 1)
|
||||
|
||||
# SWA KV attention (with swa_len mask)
|
||||
attn_swa = torch.matmul(q.float(), k_swa.float().T) * scale
|
||||
if swa_len < n_swa:
|
||||
attn_swa[:, swa_len:] = float('-inf')
|
||||
if is_causal:
|
||||
for i in range(m):
|
||||
for j in range(n_swa):
|
||||
if j > i:
|
||||
attn_swa[i, j] = float('-inf')
|
||||
o_norm_swa = torch.softmax(attn_swa, dim=-1) @ v_swa.float()
|
||||
lse_swa = torch.logsumexp(attn_swa, dim=-1, keepdim=True)
|
||||
|
||||
# Sink merge (normalized formula)
|
||||
exp_sink = attn_sink.exp()
|
||||
numerator = lse_comp.exp() * o_norm_comp + exp_sink * lse_swa.exp() * o_norm_swa
|
||||
denominator = (lse_comp.exp() + exp_sink * lse_swa.exp()).clamp(min=1e-30)
|
||||
o = numerator / denominator
|
||||
return o.to(torch.bfloat16)
|
||||
|
||||
|
||||
def test_d5c_combined():
|
||||
print("=== Stage D5c: Fused Sparse+SWA via Combined KV + Sink Bias ===\n")
|
||||
|
||||
hd = 64
|
||||
m = 128 # query rows
|
||||
n_comp = 128 # compressed KV length
|
||||
n_swa = 128 # SWA window length
|
||||
n_total = n_comp + n_swa # combined KV length = 256
|
||||
swa_len = 64 # actual SWA fill
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
torch.manual_seed(42)
|
||||
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
|
||||
k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
# Sink weight (in log domain, per-head). n_h=1 so it's a scalar.
|
||||
attn_sink = torch.tensor([0.5], dtype=torch.float32, device='cuda')
|
||||
|
||||
# === FP32 References ===
|
||||
qf = q[:, :, 0]
|
||||
|
||||
# Reference 1: Combined softmax with sink bias (our kernel's approach)
|
||||
ref_combined = reference_combined_attention(
|
||||
qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
|
||||
attn_sink[0].item(), scale, swa_len
|
||||
)
|
||||
|
||||
# Reference 2: Separate attention + sink merge (original D5b formula)
|
||||
ref_merge = reference_sink_merge(
|
||||
qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
|
||||
attn_sink[0].item(), scale, swa_len
|
||||
)
|
||||
|
||||
# Verify the two references agree
|
||||
cos_ref = torch.nn.functional.cosine_similarity(
|
||||
ref_combined.flatten().unsqueeze(0).float(),
|
||||
ref_merge.flatten().unsqueeze(0).float()
|
||||
).item()
|
||||
print(f"Reference: combined softmax vs sink merge cos = {cos_ref:.6f}")
|
||||
assert cos_ref > 0.999, f"References don't match: cos={cos_ref}"
|
||||
|
||||
# === Kernel ===
|
||||
# Concatenate KV for the kernel
|
||||
k_combined = torch.cat([k_comp, k_swa], dim=0) # (n_total, hd, 1)
|
||||
v_combined = torch.cat([v_comp, v_swa], dim=0) # (n_total, hd)
|
||||
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
kernel = FmhaKernel(
|
||||
head_dim=hd,
|
||||
s_k=n_total, # combined KV length
|
||||
normalize=False, # D5a: emit un-normalized O + LSE
|
||||
apply_swa_mask=True, # D3: mask SWA positions
|
||||
is_causal=False, # D4: no causal mask for this test
|
||||
n_comp=n_comp, # D5c: compressed KV length (sink bias starts here)
|
||||
)
|
||||
|
||||
# Allocate output
|
||||
c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
# Prepare CuTe tensors
|
||||
def to_cute(t):
|
||||
return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
|
||||
|
||||
mQ = to_cute(q)
|
||||
mK = to_cute(k_combined)
|
||||
mV = to_cute(v_combined.unsqueeze(-1).contiguous())
|
||||
mC = to_cute(c_out)
|
||||
mLSE = to_cute(lse_out)
|
||||
|
||||
# Compile
|
||||
print('Compiling D5c kernel (combined KV + sink bias)...', flush=True)
|
||||
compiled = cute.compile(
|
||||
kernel, mQ, mK, mV, mC, stream, mLSE,
|
||||
swa_len=swa_len, sink_bias=attn_sink,
|
||||
)
|
||||
|
||||
# Run
|
||||
print('Running D5c kernel...', flush=True)
|
||||
compiled(
|
||||
mQ, mK, mV, mC, stream, mLSE,
|
||||
swa_len=swa_len, sink_bias=attn_sink,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Check results
|
||||
o_kernel = c_out[:, :, 0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
o_kernel.flatten().unsqueeze(0),
|
||||
ref_combined.flatten().unsqueeze(0).float()
|
||||
).item()
|
||||
max_abs = (o_kernel - ref_combined.float()).abs().max().item()
|
||||
|
||||
status = "PASS" if cos >= 0.95 else "FAIL"
|
||||
print(f'\nD5c result: cos {cos:.6f} max_abs {max_abs:.4f} {status}')
|
||||
|
||||
if cos < 0.95:
|
||||
print(f' kernel[0,:4]={o_kernel[0,:4].tolist()}')
|
||||
print(f' ref[0,:4]={ref_combined[0,:4].tolist()}')
|
||||
|
||||
|
||||
def test_d5c_with_causal():
|
||||
"""D5c with causal mask on SWA branch."""
|
||||
print("\n=== Stage D5c: Fused Sparse+SWA with Causal Mask ===\n")
|
||||
|
||||
hd = 64
|
||||
m = 128
|
||||
n_comp = 64
|
||||
n_swa = 128
|
||||
n_total = n_comp + n_swa
|
||||
swa_len = 96 # partially filled SWA
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
torch.manual_seed(123)
|
||||
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k_comp = torch.randn(n_comp, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v_comp = torch.randn(n_comp, hd, dtype=torch.bfloat16, device='cuda')
|
||||
k_swa = torch.randn(n_swa, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v_swa = torch.randn(n_swa, hd, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
attn_sink = torch.tensor([0.3], dtype=torch.float32, device='cuda')
|
||||
qf = q[:, :, 0]
|
||||
|
||||
ref = reference_combined_attention(
|
||||
qf, k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
|
||||
attn_sink[0].item(), scale, swa_len, is_causal=True
|
||||
)
|
||||
|
||||
# Concatenate KV
|
||||
k_combined = torch.cat([k_comp, k_swa], dim=0)
|
||||
v_combined = torch.cat([v_comp, v_swa], dim=0)
|
||||
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
kernel = FmhaKernel(
|
||||
head_dim=hd,
|
||||
s_k=n_total,
|
||||
normalize=False,
|
||||
apply_swa_mask=True,
|
||||
is_causal=True,
|
||||
n_comp=n_comp,
|
||||
)
|
||||
|
||||
c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
|
||||
|
||||
def to_cute(t):
|
||||
return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
|
||||
|
||||
mQ = to_cute(q)
|
||||
mK = to_cute(k_combined)
|
||||
mV = to_cute(v_combined.unsqueeze(-1).contiguous())
|
||||
mC = to_cute(c_out)
|
||||
mLSE = to_cute(lse_out)
|
||||
|
||||
print('Compiling D5c kernel (causal + sink bias)...', flush=True)
|
||||
compiled = cute.compile(
|
||||
kernel, mQ, mK, mV, mC, stream, mLSE,
|
||||
swa_len=swa_len, sink_bias=attn_sink,
|
||||
)
|
||||
compiled(
|
||||
mQ, mK, mV, mC, stream, mLSE,
|
||||
swa_len=swa_len, sink_bias=attn_sink,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
o_kernel = c_out[:, :, 0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
o_kernel.flatten().unsqueeze(0),
|
||||
ref.flatten().unsqueeze(0).float()
|
||||
).item()
|
||||
max_abs = (o_kernel - ref.float()).abs().max().item()
|
||||
|
||||
status = "PASS" if cos >= 0.95 else "FAIL"
|
||||
print(f'D5c causal result: cos {cos:.6f} max_abs {max_abs:.4f} {status}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_d5c_combined()
|
||||
test_d5c_with_causal()
|
||||
Reference in New Issue
Block a user