"""LayerCacheHandle — typed per-call view onto one layer's cache.

Constructed by KVCacheManager.acquire() once per layer per forward.
Holds tensor references and integer indices; no allocation.
Methods expose the operations AttentionSubBlock needs without exposing
the underlying storage layout.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from dsv4.cache.paged_cache import PagedKVPool
    from dsv4.cache.state_cache import StateCachePool


@dataclass
class LayerCacheHandle:
    """Read/write interface for one layer's cache.

    The fields are the resolved indices and tensor refs for THIS call's
    batch of requests. AttentionSubBlock never sees raw pool tensors.

    ``paged`` is ``None`` for SWA-only layers; for such a handle only the
    SWA methods (``write_swa`` / ``read_swa_view``) are usable — the
    classical and indexer readers assert that a paged pool is present.
    """

    # Pool references (shared across handles — never mutated).
    paged: Optional["PagedKVPool"]
    state: "StateCachePool"

    # Per-call indices.
    request_slots: torch.Tensor  # [batch] int32 — state cache slot per request
    positions: torch.Tensor  # [tokens] int32 — absolute position per token
    request_ids: torch.Tensor  # [tokens] int32 — which request each token belongs to

    # Block table for the classical pool (None for SWA-only layers).
    # Shape: [batch, max_logical_blocks] int32. -1 padding for unused entries.
    block_table: Optional[torch.Tensor]
    # Number of valid blocks per request (excludes padding).
    block_lens: Optional[torch.Tensor]

    # ------------------------------------------------------------------
    # Methods called by AttentionSubBlock
    # ------------------------------------------------------------------

    def write_swa(
        self,
        raw_kv: torch.Tensor,  # (T, head_dim) BF16
    ) -> None:
        """Append this call's raw KV tokens to the per-request SWA state.

        Dispatches to ``append_swa_kernel`` with this batch's
        ``request_slots`` and ``positions``, the state pool's ``swa_fp8``,
        ``swa_rope``, ``swa_inv``, ``swa_pos`` and ``swa_head`` tensors, and
        ``rope_dim`` taken from the state pool's schema.

        Design intent: the same tokens also feed the tail compression
        buffer (SWA consumes the last n_win, the tail accumulates until it
        can flush). NOTE(review): no tail-buffer tensor is passed to the
        kernel here — confirm whether append_swa_kernel updates the tail
        itself or whether that happens elsewhere.
        """
        # Imported at call time rather than module import time; presumably
        # to avoid loading kernel modules eagerly — TODO confirm.
        from dsv4.kernels.cache.append_swa import append_swa_kernel

        append_swa_kernel(
            raw_kv=raw_kv,
            request_slots=self.request_slots,
            positions=self.positions,
            swa_fp8=self.state.swa_fp8,
            swa_rope=self.state.swa_rope,
            swa_inv=self.state.swa_inv,
            swa_pos=self.state.swa_pos,
            swa_head=self.state.swa_head,
            rope_dim=self.state.schema.rope_dim,
        )

    def flush_compression(
        self,
        compressed: torch.Tensor,  # (T_flush, head_dim) BF16 — newly produced
        indexer_keys: Optional[torch.Tensor] = None,
    ) -> None:
        """Promote pending tail tokens into the classical pool.

        Not implemented here: this method always raises
        ``NotImplementedError`` (the implementation is expected in
        ``kernels/cache/flush_compression.py``).

        Intended contract: called by the compressor when the tail buffer has
        enough tokens; allocates a new block if the latest block is full.
        Block allocation requires going outside the captured graph — in a
        fully-captured decode this is rare (once per m or m' tokens), so it
        is made explicit. The manager owns that contract.

        Args:
            compressed: Newly produced compressed entries to append.
            indexer_keys: Optional indexer keys accompanying ``compressed``.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("see kernels/cache/flush_compression.py")

    def read_swa_view(self) -> "SWAView":
        """Return a typed view of the SWA window for this batch."""
        return SWAView(
            fp8=self.state.swa_fp8,
            rope=self.state.swa_rope,
            inv_scale=self.state.swa_inv,
            positions=self.state.swa_pos,
            head=self.state.swa_head,
            slots=self.request_slots,
        )

    def _require_block_table(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(block_table, block_lens)``, asserting both are present."""
        # ClassicalView / IndexerView declare these as non-optional tensors;
        # fail here instead of handing None to a downstream kernel.
        assert self.block_table is not None and self.block_lens is not None, (
            "classical/indexer cache reads require block_table and block_lens"
        )
        return self.block_table, self.block_lens

    def read_classical_view(self) -> "ClassicalView":
        """Return a typed view of compressed entries for this batch.

        Raises:
            AssertionError: if this is an SWA-only handle (``paged`` is None)
                or the handle has no ``block_table`` / ``block_lens``.
        """
        assert self.paged is not None, "SWA-only layers have no classical cache"
        block_table, block_lens = self._require_block_table()
        return ClassicalView(
            entries_fp8=self.paged.entries_fp8,
            entries_rope=self.paged.entries_rope,
            inv_scale=self.paged.inv_scale,
            block_table=block_table,
            block_lens=block_lens,
        )

    def read_indexer_view(self) -> "IndexerView":
        """CSA-only. Returns FP4 indexer keys with their scales.

        Raises:
            AssertionError: if ``paged`` is None, the pool has no
                ``indexer_keys_fp4``, or the handle has no ``block_table`` /
                ``block_lens``.
        """
        assert self.paged is not None, "SWA-only layers have no indexer cache"
        assert self.paged.indexer_keys_fp4 is not None, (
            "indexer view is CSA-only; this pool has no FP4 indexer keys"
        )
        block_table, block_lens = self._require_block_table()
        return IndexerView(
            keys_fp4=self.paged.indexer_keys_fp4,
            scale=self.paged.indexer_scale,
            global_scale=self.paged.indexer_global_scale,
            block_table=block_table,
            block_lens=block_lens,
        )


# Typed views — simple dataclasses, no logic.
# The FMHA / indexer / SWA kernels accept these views to keep their
# signatures clean.


@dataclass()
class SWAView:
    """Sliding-window state tensors plus the batch's request slots."""

    fp8: torch.Tensor  # filled from StateCachePool.swa_fp8
    rope: torch.Tensor  # filled from StateCachePool.swa_rope
    inv_scale: torch.Tensor  # filled from StateCachePool.swa_inv
    positions: torch.Tensor  # filled from StateCachePool.swa_pos
    head: torch.Tensor  # filled from StateCachePool.swa_head
    slots: torch.Tensor  # the handle's request_slots


@dataclass()
class ClassicalView:
    """Compressed classical-pool entries plus the batch's block table."""

    entries_fp8: torch.Tensor  # filled from PagedKVPool.entries_fp8
    entries_rope: torch.Tensor  # filled from PagedKVPool.entries_rope
    inv_scale: torch.Tensor  # filled from PagedKVPool.inv_scale
    block_table: torch.Tensor  # the handle's block_table
    block_lens: torch.Tensor  # the handle's block_lens


@dataclass()
class IndexerView:
    """FP4 indexer keys and scales plus the batch's block table (CSA-only)."""

    keys_fp4: torch.Tensor  # filled from PagedKVPool.indexer_keys_fp4
    scale: torch.Tensor  # filled from PagedKVPool.indexer_scale
    global_scale: torch.Tensor  # filled from PagedKVPool.indexer_global_scale
    block_table: torch.Tensor  # the handle's block_table
    block_lens: torch.Tensor  # the handle's block_lens