"""DSV4 Attention kernels — public integration API.

====================================================================
STATUS: SKELETON — not yet connected to model
====================================================================
These functions define the API that AttentionSubBlock will call.
They are correct in structure but depend on:
  1. LayerCacheHandle being fully implemented (gather_compressed_kv, etc.)
  2. The production FMHA wrapper supporting sink_bias and n_comp
  3. Custom op registration for torch.compile compatibility
See ROADMAP.md Priority 5 for the full Stage E checklist.
====================================================================

These functions bridge the model's AttentionSubBlock to the production
FMHA kernel wrapper. Each function handles the cache → dense-tensor
materialization that the kernel requires.

The model's attention layer calls these after:
  1. Projection (q_down, q_up, kv_down)
  2. RoPE application
  3. Compression + cache writes
  4. Indexer + top-k (CSA only)

These functions handle:
  - Gathering sparse/dense KV from the cache into dense tensors
  - Calling the production FMHA wrapper
  - Returning the attention output for inverse RoPE + wo_a/wo_b
"""

from typing import TYPE_CHECKING, Optional, Tuple

import torch

from dsv4.kernels.attention.production import dsv4_attention

if TYPE_CHECKING:
    from dsv4.cache.handle import LayerCacheHandle


# ---------------------------------------------------------------------------
# Private helpers shared by the three attention entry points.
# ---------------------------------------------------------------------------


def _split_query_heads(
    q: torch.Tensor, cache: "LayerCacheHandle"
) -> Tuple[torch.Tensor, int, int]:
    """Reshape a flat (T, n_h * hd) query into per-head (n_h, T, hd) layout.

    The head count and head dim are read from the cache handle's config
    (``cache.num_query_heads`` / ``cache.head_dim``).

    Returns:
        ``(q_heads, n_h, hd)`` where ``q_heads`` has shape (n_h, T, hd).

    Raises:
        ValueError: if ``q`` does not hold exactly T * n_h * hd elements,
            i.e. it cannot be viewed as (T, n_h, hd).
    """
    n_h = cache.num_query_heads
    hd = cache.head_dim
    T = q.shape[0]
    # Same acceptance condition as the reshape below, but with an error that
    # names the expected layout instead of torch's generic reshape failure.
    if q.numel() != T * n_h * hd:
        raise ValueError(
            f"query of shape {tuple(q.shape)} cannot be split into "
            f"(T={T}, n_h={n_h}, hd={hd}); expected (T, n_h * hd) = "
            f"(T, {n_h * hd})"
        )
    return q.reshape(T, n_h, hd).permute(1, 0, 2), n_h, hd


def _merge_heads(out: torch.Tensor, T: int, n_h: int, hd: int) -> torch.Tensor:
    """Inverse of the query split: (n_h, T, hd) -> (T, n_h * hd)."""
    return out.permute(1, 0, 2).reshape(T, n_h * hd)


def _append_swa_kv(
    cache: "LayerCacheHandle",
    k_compressed: torch.Tensor,
    v_compressed: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Concatenate [compressed, SWA] KV along the sequence axis (dim -2).

    A single concatenated KV lets the kernel run one softmax over both
    branches (D5c insight).

    Returns:
        ``(k_full, v_full, n_comp)`` where ``n_comp`` is the length of the
        compressed prefix, which the kernel uses as the sink-bias offset.
    """
    k_swa, v_swa = cache.gather_swa_kv()
    k_full = torch.cat([k_compressed, k_swa], dim=-2)
    v_full = torch.cat([v_compressed, v_swa], dim=-2)
    return k_full, v_full, k_compressed.shape[-2]


def _run_kernel(
    q_heads: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    n_comp: int,
    sink_logits: Optional[torch.Tensor],
    sliding_window: int,
) -> torch.Tensor:
    """Invoke the production FMHA wrapper with the flags shared by all paths.

    ``dsv4_attention`` is looked up as a module global at call time, so
    tests that monkeypatch it on this module keep working.
    """
    return dsv4_attention(
        q_heads,
        k,
        v,
        swa_len=sliding_window,
        is_causal=True,
        n_comp=n_comp,
        sink_bias=sink_logits,
    )


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def sparse_fmha_with_swa(
    q: torch.Tensor,  # (T, n_h * hd) BF16, post-RoPE
    cache: "LayerCacheHandle",  # provides compressed + SWA KV
    selected_indices: torch.Tensor,  # (T, top_k) int64 — which compressed blocks
    sink_logits: Optional[torch.Tensor] = None,  # (n_h,) FP32
    sliding_window: int = 128,
) -> torch.Tensor:
    """CSA attention: sparse top-k compressed KV + sliding window, fused sink merge.

    Gathers the top-k compressed KV blocks and the SWA window into one
    contiguous KV tensor, then calls the production FMHA with sink bias.

    Args:
        q: (T, n_h * hd) BF16 query (post-RoPE, pre-reshape).
        cache: LayerCacheHandle with CSA compressed entries + SWA window.
        selected_indices: (T, top_k) int64 block indices from the indexer.
        sink_logits: (n_h,) FP32 per-head sink bias.
        sliding_window: SWA window length.

    Returns:
        (T, n_h * hd) BF16 attention output (pre inverse-RoPE).

    Raises:
        ValueError: if ``q`` cannot be viewed as (T, n_h, hd).
    """
    q_heads, n_h, hd = _split_query_heads(q, cache)  # (n_h, T, hd)

    # Materialize the selected compressed blocks from the paged pool.
    # Expected: (1, n_comp_kv, hd) or (n_kv, n_comp_kv, hd); V has the same shape.
    # NOTE(review): selected_indices is per-query (T, top_k), but the gathered
    # KV is shared by all T queries. Confirm that gather_compressed_kv and/or
    # the kernel apply the per-query selection. Otherwise each query also
    # attends to blocks selected for other queries.
    k_compressed, v_compressed = cache.gather_compressed_kv(selected_indices)

    k_full, v_full, n_comp = _append_swa_kv(cache, k_compressed, v_compressed)

    # MQA: DSV4 uses a single KV head (n_kv = 1).
    output = _run_kernel(q_heads, k_full, v_full, n_comp, sink_logits, sliding_window)
    return _merge_heads(output, q.shape[0], n_h, hd)


def dense_fmha_with_swa(
    q: torch.Tensor,
    cache: "LayerCacheHandle",
    sink_logits: Optional[torch.Tensor] = None,
    sliding_window: int = 128,
) -> torch.Tensor:
    """HCA attention: dense over all compressed KV + SWA window, fused sink merge.

    No indexer: every compressed entry is attended. With m'=128 compression
    the compressed sequence is very short, so dense attention is cheap.

    Args:
        q: (T, n_h * hd) BF16 query.
        cache: LayerCacheHandle with HCA compressed entries + SWA window.
        sink_logits: (n_h,) FP32 per-head sink bias.
        sliding_window: SWA window length.

    Returns:
        (T, n_h * hd) BF16 attention output.

    Raises:
        ValueError: if ``q`` cannot be viewed as (T, n_h, hd).
    """
    q_heads, n_h, hd = _split_query_heads(q, cache)

    # Dense: gather ALL compressed KV (no indexer needed).
    k_compressed, v_compressed = cache.gather_all_compressed_kv()

    k_full, v_full, n_comp = _append_swa_kv(cache, k_compressed, v_compressed)

    output = _run_kernel(q_heads, k_full, v_full, n_comp, sink_logits, sliding_window)
    return _merge_heads(output, q.shape[0], n_h, hd)


def swa_only_fmha(
    q: torch.Tensor,
    cache: "LayerCacheHandle",
    sink_logits: Optional[torch.Tensor] = None,
    sliding_window: int = 128,
) -> torch.Tensor:
    """SWA-only attention: pure local attention over the sliding window.

    No compression branch and no indexer. Used for the first two layers of
    the Flash variant.

    Args:
        q: (T, n_h * hd) BF16 query.
        cache: LayerCacheHandle with SWA window.
        sink_logits: (n_h,) FP32 per-head sink bias.
        sliding_window: SWA window length.

    Returns:
        (T, n_h * hd) BF16 attention output.

    Raises:
        ValueError: if ``q`` cannot be viewed as (T, n_h, hd).
    """
    q_heads, n_h, hd = _split_query_heads(q, cache)

    k_swa, v_swa = cache.gather_swa_kv()

    # n_comp=0: there is no compressed prefix, so the KV is the SWA window
    # alone. The sink bias is still applied.
    output = _run_kernel(q_heads, k_swa, v_swa, 0, sink_logits, sliding_window)
    return _merge_heads(output, q.shape[0], n_h, hd)