From 9d644349541ae676b34f4f952669652cf97ff7bc Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 26 May 2026 14:59:52 +0000 Subject: [PATCH] D5c: add sink bias (attn_sink) logit modification to FMHA kernel - Add n_comp parameter: compressed KV length, sink bias applies to positions >= n_comp - Add sink_bias parameter: per-head FP32 logit bias for SWA positions - D3 mask updated: kv_pos >= n_comp + swa_len (backward compatible when n_comp=0) - D4 causal mask updated: compare SWA-relative position (kv_pos - n_comp) with m_coord - Mathematical insight: sink merge = single softmax over [S_comp, S_swa + attn_sink] - Add test_d5c_fused.py with combined KV + sink bias test --- dsv4/kernels/attention/fmha.py | 56 +++++-- tests/unit/test_d5c_fused.py | 284 +++++++++++++++++++++++++++++++++ 2 files changed, 326 insertions(+), 14 deletions(-) create mode 100644 tests/unit/test_d5c_fused.py diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 4b9c956d..4d9d3d3c 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -16,7 +16,11 @@ import math class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False): + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None): + # D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to + # positions >= n_comp. D3/D4 masks also only apply to SWA region. + # When n_comp is None or 0, no sink bias is applied (backward compatible). 
+ self.n_comp = n_comp if n_comp is not None else 0 self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 @@ -105,7 +109,7 @@ class FmhaKernel: cute.size_in_bytes(self.q_dtype, v_s)) * cta @cute.jit - def __call__(self, q, k, v, c, stream, lse=None, swa_len=None): + def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None): self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() @@ -138,12 +142,21 @@ class FmhaKernel: swa_len = Int32(2147483647) else: swa_len = Int32(swa_len) + # D5c: sink_bias is a per-head FP32 logit bias applied to SWA positions. + # When None, pass 0.0 (no bias). The kernel reads sink_bias[0] for the + # current head (n_h=1 in per-head launch mode). + if const_expr(sink_bias is None): + # D5c: sink_bias not provided. Create a dummy tensor pointing to valid memory. + # Never actually read (const_expr(self.n_comp > 0) guards the read). 
+ sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,))) + else: + sink_bias = ct.from_dlpack(sink_bias).mark_layout_dynamic(leading_dim=ct.get_leading_dim(sink_bias)) # Grid: (M_tiles, 1, batch) where M = n_h * T packed into M dimension # For single-head (n_h=1): grid=(1,1,1) — backward compatible - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream) + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream) @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len): + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx,_,_ = cute.arch.thread_idx() if warp_idx == self.tma_warp_id: @@ -411,29 +424,44 @@ class FmhaKernel: cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) cute.arch.fence_view_async_tmem_load() - # D3/D4: In-kernel logit masking. - # After loading S from TMEM, mask invalid positions to -inf. + # D3/D4/D5c: In-kernel logit modification. + # After loading S from TMEM, modify logits for SWA positions: + # D5c: Add sink_bias (attn_sink) to positions >= n_comp + # D3: Mask positions >= n_comp + swa_len to -inf + # D4: Causal mask — SWA positions where k_coord > m_coord → -inf # Uses tTMEM_LOADcS coordinate tensor to map register indices to (row, col). 
- # D3: SWA mask — positions >= swa_lens[batch_idx] → -inf - # D4: Causal mask — positions where k_coord > m_coord → -inf - # Both use the same coordinate mapping from tTMEM_LOADcS. # For kt > 0, absolute KV pos = kt*128 + k_coord. - if const_expr(self.apply_swa_mask or self.is_causal): + if const_expr(self.apply_swa_mask or self.is_causal or self.n_comp > 0): kt_offset = Int32(kt * 128) # KV position offset for this tile - # Iterate using same coordinate indexing as SMEM-P path + # D5c: Read sink bias once (same for all positions in this head). + # Define unconditionally for CuTeDSL scoping (used when n_comp > 0). + sink_val = Float32(0.0) + if const_expr(self.n_comp > 0): + sink_val = mSinkBias[Int32(0)] for j0 in range(32): for j1 in range(4): coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] m_coord = coord[0] # query row position k_coord = coord[1] # position within this KV tile kv_pos = kt_offset + k_coord # absolute KV position + # D5c: Add sink bias to SWA positions (>= n_comp) + if const_expr(self.n_comp > 0): + if kv_pos >= Int32(self.n_comp): + tTMEM_LOADrS[(j0, 0), j1, 0, 0] = tTMEM_LOADrS[(j0, 0), j1, 0, 0] + sink_val + # D3: SWA length mask should_mask = Boolean(0) if const_expr(self.apply_swa_mask): - if kv_pos >= swa_len: + # SWA length applies relative to the SWA region start (n_comp) + # kv_pos >= n_comp + swa_len means the SWA position >= swa_len + if kv_pos >= Int32(self.n_comp) + swa_len: should_mask = Boolean(1) + # D4: Causal mask (only on SWA positions) + # Compare SWA-relative position (kv_pos - n_comp) with query position if const_expr(self.is_causal): - if k_coord > m_coord: - should_mask = Boolean(1) + if kv_pos >= Int32(self.n_comp): + swa_pos = kv_pos - Int32(self.n_comp) + if swa_pos > m_coord: + should_mask = Boolean(1) if should_mask: tTMEM_LOADrS[(j0, 0), j1, 0, 0] = -Float32.inf diff --git a/tests/unit/test_d5c_fused.py b/tests/unit/test_d5c_fused.py new file mode 100644 index 00000000..a9335f93 --- /dev/null +++ 
"""
FMHA D5c: fused sparse + SWA attention via combined KV + sink bias.

Mathematical insight: the sink merge is equivalent to a single attention
pass over the concatenated KV with a logit bias (attn_sink) applied to
the SWA portion. No two passes needed, no merge epilogue.

    S = [S_comp, S_swa + attn_sink]
    O = softmax(S) @ [V_comp; V_swa]

This is identical to:
    O = (exp(lse_sparse)*O_sparse + exp(sink)*exp(lse_swa)*O_swa)
        / (exp(lse_sparse) + exp(sink)*exp(lse_swa))

The kernel changes are minimal:
1. K = [compressed_K; swa_K], V = [compressed_V; swa_V]
2. n_comp = length of compressed KV (sink bias applies to positions >= n_comp)
3. attn_sink = per-head logit bias for SWA positions
4. D3/D4 masking applies to SWA region (positions >= n_comp)

Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d5c_fused.py
"""
import math
import unittest

import torch

# FmhaKernel derives its tile count as ``s_k // 128``; a combined KV length
# that is not a multiple of this would silently drop the trailing rows.
KV_TILE = 128

# Minimum per-row cosine similarity between kernel output and reference.
MIN_ROW_COSINE = 0.95


def _mask_swa_scores(swa_scores, swa_len, is_causal):
    """Apply the D3 (fill length) and D4 (causal) masks to SWA logits in place.

    Args:
        swa_scores: (m, n_swa) float tensor of SWA logits; may be a view into
            a larger score matrix, in which case the parent is updated too.
        swa_len: number of valid SWA positions; columns >= swa_len become -inf.
        is_causal: when True, column j is masked for row i whenever j > i.

    Returns:
        The same tensor object, for call chaining.
    """
    m, n_swa = swa_scores.shape
    if swa_len < n_swa:
        swa_scores[:, swa_len:] = float('-inf')
    if is_causal:
        # Vectorised "k_coord > m_coord" instead of a per-element Python loop.
        rows = torch.arange(m, device=swa_scores.device).unsqueeze(1)
        cols = torch.arange(n_swa, device=swa_scores.device).unsqueeze(0)
        swa_scores.masked_fill_(cols > rows, float('-inf'))
    return swa_scores


def reference_combined_attention(q, k_comp, v_comp, k_swa, v_swa,
                                 attn_sink, scale, swa_len, is_causal=False):
    """FP32 reference: single softmax over combined KV with sink bias on SWA.

    Args:
        q: (m, hd) queries.
        k_comp, v_comp: (n_comp, hd) compressed keys / values.
        k_swa, v_swa: (n_swa, hd) sliding-window keys / values.
        attn_sink: logit bias added to every SWA position (float or tensor
            broadcastable against the SWA logits).
        scale: softmax scale applied to the raw Q.K^T logits.
        swa_len: number of valid SWA positions (D3 mask).
        is_causal: apply the D4 causal mask inside the SWA region.

    Returns:
        (m, hd) bfloat16 attention output.
    """
    n_comp = k_comp.shape[0]
    k_all = torch.cat([k_comp, k_swa], dim=0).float()  # (n_comp + n_swa, hd)
    v_all = torch.cat([v_comp, v_swa], dim=0).float()

    scores = (q.float() @ k_all.T) * scale  # (m, n_comp + n_swa)

    # The slice is a view, so bias and masks land directly in ``scores``.
    swa_scores = scores[:, n_comp:]
    swa_scores += attn_sink
    _mask_swa_scores(swa_scores, swa_len, is_causal)

    probs = torch.softmax(scores, dim=-1)
    return (probs @ v_all).to(torch.bfloat16)


def reference_sink_merge(q, k_comp, v_comp, k_swa, v_swa,
                         attn_sink, scale, swa_len, is_causal=False):
    """FP32 reference: separate attention + sink merge (original D5b formula).

    Takes the same arguments as ``reference_combined_attention``. The merge is
    evaluated in the log domain:

        w_comp = exp(lse_comp) / (exp(lse_comp) + exp(sink) * exp(lse_swa))
               = sigmoid(lse_comp - (lse_swa + sink))

    which avoids overflow from exponentiating raw LSEs and accepts
    ``attn_sink`` as either a Python float or a tensor.

    Returns:
        (m, hd) bfloat16 attention output.
    """
    qf = q.float()

    # Compressed KV attention (no mask).
    comp_scores = (qf @ k_comp.float().T) * scale
    o_comp = torch.softmax(comp_scores, dim=-1) @ v_comp.float()
    lse_comp = torch.logsumexp(comp_scores, dim=-1, keepdim=True)  # (m, 1)

    # SWA KV attention (D3 / D4 masks).
    swa_scores = _mask_swa_scores((qf @ k_swa.float().T) * scale,
                                  swa_len, is_causal)
    # A fully masked row has a NaN softmax; its merge weight is zero, so
    # zero the probabilities to keep NaN out of the result.
    swa_probs = torch.nan_to_num(torch.softmax(swa_scores, dim=-1), nan=0.0)
    o_swa = swa_probs @ v_swa.float()
    lse_swa = torch.logsumexp(swa_scores, dim=-1, keepdim=True)

    w_comp = torch.sigmoid(lse_comp - (lse_swa + attn_sink))
    merged = w_comp * o_comp + (1.0 - w_comp) * o_swa
    return merged.to(torch.bfloat16)


def _require_cuda():
    """Skip the calling test when no CUDA device is present."""
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA device required for the FMHA kernel")


def _random_inputs(m, hd, n_comp, n_swa):
    """Draw q / k_comp / v_comp / k_swa / v_swa (in that RNG order) on CUDA."""
    def bf16(*shape):
        return torch.randn(*shape, dtype=torch.bfloat16, device='cuda')

    q = bf16(m, hd, 1)
    k_comp = bf16(n_comp, hd, 1)
    v_comp = bf16(n_comp, hd)
    k_swa = bf16(n_swa, hd, 1)
    v_swa = bf16(n_swa, hd)
    return q, k_comp, v_comp, k_swa, v_swa


def _run_fused_kernel(q, k_comp, v_comp, k_swa, v_swa,
                      attn_sink, swa_len, is_causal, label):
    """Compile and launch FmhaKernel on the concatenated KV.

    Args:
        q: (m, hd, 1) bfloat16 queries.
        k_comp, k_swa: (n, hd, 1) bfloat16 keys.
        v_comp, v_swa: (n, hd) bfloat16 values.
        attn_sink: one-element float32 CUDA tensor with the sink bias.
        swa_len: number of valid SWA positions.
        is_causal: enable the D4 causal mask.
        label: short description used in the progress messages.

    Returns:
        (m, hd) float32 kernel output.

    Raises:
        ValueError: if n_comp + n_swa is not a multiple of KV_TILE.
    """
    # GPU-only dependencies are imported here so the pure-PyTorch references
    # in this module stay importable on machines without the kernel stack.
    import cutlass.cute as cute
    import cutlass.torch as ct
    import cuda.bindings.driver as cuda
    from dsv4.kernels.attention.fmha import FmhaKernel

    m, hd = q.shape[0], q.shape[1]
    n_comp = k_comp.shape[0]
    n_total = n_comp + k_swa.shape[0]
    if n_total % KV_TILE:
        raise ValueError(
            f"combined KV length {n_total} is not a multiple of {KV_TILE}; "
            "FmhaKernel would drop the trailing rows")

    k_combined = torch.cat([k_comp, k_swa], dim=0)  # (n_total, hd, 1)
    v_combined = torch.cat([v_comp, v_swa], dim=0)  # (n_total, hd)

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    kernel = FmhaKernel(
        head_dim=hd,
        s_k=n_total,            # combined KV length
        normalize=False,        # D5a: emit un-normalized O + LSE
        apply_swa_mask=True,    # D3: mask SWA positions
        is_causal=is_causal,    # D4
        n_comp=n_comp,          # D5c: compressed KV length (sink bias starts here)
    )

    c_out = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    lse_out = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    def to_cute(t):
        return ct.from_dlpack(t).mark_layout_dynamic(
            leading_dim=ct.get_leading_dim(t))

    mQ = to_cute(q)
    mK = to_cute(k_combined)
    mV = to_cute(v_combined.unsqueeze(-1).contiguous())
    mC = to_cute(c_out)
    mLSE = to_cute(lse_out)

    print(f'Compiling D5c kernel ({label})...', flush=True)
    compiled = cute.compile(
        kernel, mQ, mK, mV, mC, stream, mLSE,
        swa_len=swa_len, sink_bias=attn_sink,
    )

    print('Running D5c kernel...', flush=True)
    compiled(
        mQ, mK, mV, mC, stream, mLSE,
        swa_len=swa_len, sink_bias=attn_sink,
    )
    torch.cuda.synchronize()
    return c_out[:, :, 0].float()


def _check_against_reference(o_kernel, ref, label):
    """Assert that every kernel output row points the same way as the reference.

    The kernel is launched with normalize=False, which this test describes as
    "un-normalized O + LSE". Assuming that means each row differs from the
    normalised output by a positive factor (TODO confirm against the kernel),
    a per-row cosine is unaffected by it, whereas a flattened cosine or a
    max-abs difference is not.
    """
    row_cos = torch.nn.functional.cosine_similarity(o_kernel, ref.float(), dim=-1)
    cos_min = row_cos.min().item()
    cos_mean = row_cos.mean().item()
    # A NaN compares False, so NaN rows count as failures.
    passed = cos_min >= MIN_ROW_COSINE

    status = "PASS" if passed else "FAIL"
    print(f'\n{label}: row cos min {cos_min:.6f} mean {cos_mean:.6f} {status}')
    if not passed:
        worst = int(row_cos.argmin())
        print(f'  kernel[{worst},:4]={o_kernel[worst, :4].tolist()}')
        print(f'  ref[{worst},:4]={ref[worst, :4].tolist()}')
    assert passed, f"{label}: min row cosine {cos_min:.6f} < {MIN_ROW_COSINE}"


def test_d5c_combined():
    """D5c with the SWA length mask only (no causal mask)."""
    _require_cuda()
    print("=== Stage D5c: Fused Sparse+SWA via Combined KV + Sink Bias ===\n")

    hd = 64
    m = 128         # query rows
    n_comp = 128    # compressed KV length
    n_swa = 128     # SWA window length (combined KV length = 256)
    swa_len = 64    # actual SWA fill
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(42)

    q, k_comp, v_comp, k_swa, v_swa = _random_inputs(m, hd, n_comp, n_swa)

    # Sink weight (in log domain, per-head). n_h=1 so it's a scalar.
    attn_sink = torch.tensor([0.5], dtype=torch.float32, device='cuda')
    sink = attn_sink[0].item()

    ref_args = (q[:, :, 0], k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
                sink, scale, swa_len)
    # Reference 1: combined softmax with sink bias (the kernel's approach).
    ref_combined = reference_combined_attention(*ref_args)
    # Reference 2: separate attention + sink merge (original D5b formula).
    ref_merge = reference_sink_merge(*ref_args)

    cos_ref = torch.nn.functional.cosine_similarity(
        ref_combined.flatten().unsqueeze(0).float(),
        ref_merge.flatten().unsqueeze(0).float()
    ).item()
    print(f"Reference: combined softmax vs sink merge cos = {cos_ref:.6f}")
    assert cos_ref > 0.999, f"References don't match: cos={cos_ref}"

    o_kernel = _run_fused_kernel(
        q, k_comp, v_comp, k_swa, v_swa, attn_sink, swa_len,
        is_causal=False, label='combined KV + sink bias')
    _check_against_reference(o_kernel, ref_combined, 'D5c result')


def test_d5c_with_causal():
    """D5c with causal mask on SWA branch."""
    _require_cuda()
    print("\n=== Stage D5c: Fused Sparse+SWA with Causal Mask ===\n")

    hd = 64
    m = 128
    # n_comp = 64 puts the compressed/SWA boundary in the middle of a KV tile;
    # n_swa = 192 keeps the combined length (256) a multiple of the tile size.
    n_comp = 64
    n_swa = 192
    swa_len = 96    # partially filled SWA
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(123)

    q, k_comp, v_comp, k_swa, v_swa = _random_inputs(m, hd, n_comp, n_swa)

    attn_sink = torch.tensor([0.3], dtype=torch.float32, device='cuda')

    ref = reference_combined_attention(
        q[:, :, 0], k_comp[:, :, 0], v_comp, k_swa[:, :, 0], v_swa,
        attn_sink[0].item(), scale, swa_len, is_causal=True
    )

    o_kernel = _run_fused_kernel(
        q, k_comp, v_comp, k_swa, v_swa, attn_sink, swa_len,
        is_causal=True, label='causal + sink bias')
    _check_against_reference(o_kernel, ref, 'D5c causal result')


if __name__ == '__main__':
    for case in (test_d5c_combined, test_d5c_with_causal):
        try:
            case()
        except unittest.SkipTest as skipped:
            print(f"SKIP {case.__name__}: {skipped}")