"""State cache: SWA window + uncompressed tail buffer. One slot per active request. Slot index is fixed for a request's lifetime — the manager hands out slot indices at request admission and reclaims them at completion. Per paper §3.5.1: SWA and tail tokens are state-space-like — they depend only on the current position, not on a paged history. No block table; a flat [max_requests, ...] tensor. """ from __future__ import annotations import torch from dsv4.cache.schema import LayerCacheSchema, AttentionType class StateCachePool: """Per-layer state cache (SWA window + uncompressed tail). Storage layout per slot: swa_fp8: [n_win, head_dim - rope_dim] FP8 raw KV in window swa_rope: [n_win, rope_dim] BF16 RoPE'd half swa_inv: [n_win] FP32 per-token inv scale swa_pos: [n_win] int32 — absolute position of each window slot (-1 if invalid) tail_ka: [tail_size, head_dim] BF16 raw — pending tokens not yet compressed tail_za: [tail_size, head_dim] BF16 — compression weights (Z stream for CSA, single Z for HCA) tail_kb: [tail_size, head_dim] BF16 — second stream (CSA only) tail_zb: [tail_size, head_dim] BF16 — second Z stream (CSA only) tail_len: scalar int32 — how many tail entries are valid """ def __init__( self, schema: LayerCacheSchema, max_requests: int, device: str = "cuda", ): self.schema = schema self.max_requests = max_requests self.device = device mr = max_requests nw = schema.swa_window_size hd = schema.entry_head_dim rd = schema.rope_dim fp8 = hd - rd # SWA window — circular within each slot. Layer's attention # kernel uses swa_pos to mask invalid entries. self.swa_fp8 = torch.zeros((mr, nw, fp8), dtype=torch.uint8, device=device) self.swa_rope = torch.zeros((mr, nw, rd), dtype=torch.bfloat16, device=device) self.swa_inv = torch.ones((mr, nw), dtype=torch.float32, device=device) self.swa_pos = torch.full((mr, nw), -1, dtype=torch.int32, device=device) # Next write position within each slot's ring buffer. self.swa_head = torch.zeros((mr,), dtype=torch.int32, device=device) # Tail buffer — only non-empty for compressed layers. tail = schema.tail_buffer_size if tail > 0: # For CSA we need two streams (Ca/Cb, Za/Zb) since the # compressor uses overlapping pairs. HCA only needs one # stream. Store both; HCA leaves the b-channel zero. self.tail_ka = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device) self.tail_za = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device) if schema.attn_type == AttentionType.CSA: self.tail_kb = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device) self.tail_zb = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device) else: self.tail_kb = None self.tail_zb = None self.tail_len = torch.zeros((mr,), dtype=torch.int32, device=device) else: self.tail_ka = self.tail_kb = None self.tail_za = self.tail_zb = None self.tail_len = None def reset_slot(self, slot: int) -> None: """Clear a request's state after completion.""" self.swa_pos[slot].fill_(-1) self.swa_head[slot] = 0 if self.tail_len is not None: self.tail_len[slot] = 0 def memory_bytes(self) -> int: """Total GPU memory used by this pool.""" total = 0 for name in ("swa_fp8", "swa_rope", "swa_inv", "swa_pos", "swa_head", "tail_ka", "tail_za", "tail_kb", "tail_zb", "tail_len"): t = getattr(self, name) if t is not None: total += t.numel() * t.element_size() return total