143 lines
5.0 KiB
Python
143 lines
5.0 KiB
Python
|
|
"""LayerCacheHandle — typed per-call view onto one layer's cache.
|
||
|
|
|
||
|
|
Constructed by KVCacheManager.acquire() once per layer per forward.
|
||
|
|
Holds tensor references and integer indices; no allocation. Methods
|
||
|
|
expose the operations AttentionSubBlock needs without exposing the
|
||
|
|
underlying storage layout.
|
||
|
|
"""
|
||
|
|
from __future__ import annotations
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import Optional, TYPE_CHECKING
|
||
|
|
import torch
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from dsv4.cache.paged_cache import PagedKVPool
|
||
|
|
from dsv4.cache.state_cache import StateCachePool
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class LayerCacheHandle:
    """Read/write interface for one layer's cache.

    The fields are the resolved indices and tensor refs for THIS call's
    batch of requests. AttentionSubBlock never sees raw pool tensors.

    The invariant checks in the ``read_*_view`` methods raise
    ``AssertionError`` explicitly instead of using ``assert`` statements,
    so they still fire when Python runs with ``-O`` (which strips asserts).
    """

    # Pool references (shared across handles — never mutated).
    # ``paged`` is None for SWA-only layers, which have no classical cache.
    paged: Optional["PagedKVPool"]
    state: "StateCachePool"

    # Per-call indices.
    request_slots: torch.Tensor  # [batch] int32 — state cache slot per request
    positions: torch.Tensor  # [tokens] int32 — absolute position per token
    request_ids: torch.Tensor  # [tokens] int32 — which request each token belongs to

    # Block table for the classical pool (None for SWA-only layers).
    # Shape: [batch, max_logical_blocks] int32. -1 padding for unused entries.
    block_table: Optional[torch.Tensor]
    # Number of valid blocks per request (excludes padding).
    block_lens: Optional[torch.Tensor]

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _require_paged(self, message: str) -> "PagedKVPool":
        """Return the classical pool, raising ``AssertionError(message)`` if absent."""
        # An explicit raise (not ``assert``) so the check survives ``python -O``;
        # AssertionError keeps the exception type the original asserts produced.
        if self.paged is None:
            raise AssertionError(message)
        return self.paged

    # ------------------------------------------------------------------
    # Methods called by AttentionSubBlock
    # ------------------------------------------------------------------

    def write_swa(
        self,
        raw_kv: torch.Tensor,  # (T, head_dim) BF16
    ) -> None:
        """Write raw KV into the SWA ring buffer AND tail compression buffer.

        Both regions get the same tokens — SWA consumes the last n_win,
        the tail accumulates until it can flush.

        Args:
            raw_kv: New KV rows for this call's tokens, (T, head_dim) BF16.
        """
        # NOTE(review): imported lazily, presumably to keep kernel modules out
        # of import time for this file — confirm.
        from dsv4.kernels.cache.append_swa import append_swa_kernel

        # NOTE(review): only the SWA state tensors are passed below; confirm the
        # tail-buffer write described in the docstring happens inside the kernel.
        append_swa_kernel(
            raw_kv=raw_kv,
            request_slots=self.request_slots,
            positions=self.positions,
            swa_fp8=self.state.swa_fp8,
            swa_rope=self.state.swa_rope,
            swa_inv=self.state.swa_inv,
            swa_pos=self.state.swa_pos,
            swa_head=self.state.swa_head,
            rope_dim=self.state.schema.rope_dim,
        )

    def flush_compression(
        self,
        compressed: torch.Tensor,  # (T_flush, head_dim) BF16 — newly produced
        indexer_keys: Optional[torch.Tensor] = None,
    ) -> None:
        """Promote pending tail tokens into the classical pool.

        Called by the compressor when the tail buffer has enough tokens.
        Allocates a new block if the latest block is full.

        Block allocation requires going outside the captured graph — in
        a fully-captured decode this is rare (once per m or m' tokens),
        so we make it explicit. The manager has the contract.

        Args:
            compressed: Newly produced compressed entries, (T_flush, head_dim) BF16.
            indexer_keys: Optional indexer keys to store with them —
                presumably only supplied for CSA layers; TODO confirm.

        Raises:
            NotImplementedError: Always; the flush is not implemented on the handle.
        """
        raise NotImplementedError("see kernels/cache/flush_compression.py")

    def read_swa_view(self) -> "SWAView":
        """Return a typed view of the SWA window for this batch."""
        return SWAView(
            fp8=self.state.swa_fp8,
            rope=self.state.swa_rope,
            inv_scale=self.state.swa_inv,
            positions=self.state.swa_pos,
            head=self.state.swa_head,
            slots=self.request_slots,
        )

    def read_classical_view(self) -> "ClassicalView":
        """Return a typed view of compressed entries for this batch.

        Raises:
            AssertionError: If this is an SWA-only layer (``paged`` is None).
        """
        paged = self._require_paged("SWA-only layers have no classical cache")
        return ClassicalView(
            entries_fp8=paged.entries_fp8,
            entries_rope=paged.entries_rope,
            inv_scale=paged.inv_scale,
            block_table=self.block_table,
            block_lens=self.block_lens,
        )

    def read_indexer_view(self) -> "IndexerView":
        """CSA-only. Returns FP4 indexer keys with their scales.

        Raises:
            AssertionError: If the layer has no classical pool, or its pool
                carries no FP4 indexer keys.
        """
        paged = self._require_paged("SWA-only layers have no indexer cache")
        if paged.indexer_keys_fp4 is None:
            # Without this check (e.g. under -O) a view with keys_fp4=None
            # would silently reach the indexer kernels.
            raise AssertionError("this layer's pool has no FP4 indexer keys (CSA-only)")
        return IndexerView(
            keys_fp4=paged.indexer_keys_fp4,
            scale=paged.indexer_scale,
            global_scale=paged.indexer_global_scale,
            block_table=self.block_table,
            block_lens=self.block_lens,
        )
|
||
|
|
|
||
|
|
|
||
|
|
# Typed views — simple dataclasses, no logic. The FMHA / indexer / SWA
# kernels accept these to keep their signatures clean.
|
||
|
|
@dataclass()
class SWAView:
    """Typed bundle of sliding-window (SWA) cache tensors for one batch.

    Plain data carrier with no behavior. ``LayerCacheHandle.read_swa_view``
    fills it from the state pool's ``swa_*`` tensors plus the handle's
    per-request slots.
    """

    fp8: torch.Tensor  # filled from state.swa_fp8
    rope: torch.Tensor  # filled from state.swa_rope
    inv_scale: torch.Tensor  # filled from state.swa_inv
    positions: torch.Tensor  # filled from state.swa_pos
    head: torch.Tensor  # filled from state.swa_head
    slots: torch.Tensor  # filled from the handle's request_slots
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass()
class ClassicalView:
    """Typed bundle of classical (compressed, paged) cache tensors for one batch.

    Plain data carrier with no behavior. ``LayerCacheHandle.read_classical_view``
    fills it from the paged pool plus the handle's block table.
    """

    entries_fp8: torch.Tensor  # filled from paged.entries_fp8
    entries_rope: torch.Tensor  # filled from paged.entries_rope
    inv_scale: torch.Tensor  # filled from paged.inv_scale
    # NOTE(review): the two fields below are annotated as Tensor, but the
    # handle's block_table / block_lens are Optional and are passed through
    # unchecked — they may be None in practice; confirm.
    block_table: torch.Tensor
    block_lens: torch.Tensor
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass()
class IndexerView:
    """Typed bundle of FP4 indexer-key tensors for one batch (CSA layers).

    Plain data carrier with no behavior. ``LayerCacheHandle.read_indexer_view``
    fills it from the paged pool's indexer tensors plus the handle's block table.
    """

    keys_fp4: torch.Tensor  # filled from paged.indexer_keys_fp4
    scale: torch.Tensor  # filled from paged.indexer_scale
    global_scale: torch.Tensor  # filled from paged.indexer_global_scale
    # NOTE(review): annotated as Tensor, but the handle passes its Optional
    # block_table / block_lens through unchecked — may be None; confirm.
    block_table: torch.Tensor
    block_lens: torch.Tensor
|