From b4d58df620bb591489a15483cfcbc770c4a552e4 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 00:08:38 +0000 Subject: [PATCH] KV Cache: schema, allocator, pools, manager, append_swa kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Complete KV cache substrate for DSV4 inference: schema.py: Per-layer cache shape derived from LayerSpec. - CSA: 32 entries/block, 32 indexer entries, tail=3 - HCA: 1 entry/block, no indexer, tail=127 - SWA: no classical pool, no tail - BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios) - compute_block_budget() for allocator sizing allocator.py: Fixed-size block free-list. - GPU stack with pinned host top pointer - acquire/release between graph captures only - OOM raises on exhaustion paged_cache.py: Per-layer classical KV storage. - FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper 2.3.4) - Per-entry inverse scale for FP8 dequant - FP4 indexer keys for CSA layers (NVFP4 scheme) - memory_bytes() tracking state_cache.py: Per-layer SWA window + tail buffer. - Ring buffer with position tracking (swa_head, swa_pos) - CSA: dual streams (ka/za/kb/zb) for overlapping compression - HCA: single stream (ka/za only) - SWA: no tail buffer - reset_slot() for request completion handle.py: LayerCacheHandle — typed per-call view. - write_swa(), read_swa_view(), read_classical_view(), read_indexer_view() - No GPU allocation in acquire() — 0 bytes delta (cudagraph safe) - SWAView/ClassicalView/IndexerView dataclasses for kernel signatures manager.py: KVCacheManager — owns everything. - Per-layer schema, pool, and allocator construction - admit_request()/release_request() lifecycle - allocate_block() for compression flush - acquire() returns LayerCacheHandle (zero-alloc) append_swa.cu: Native kernel for SWA writes. - One block per token, 128 threads per block - Warp-level amax reduction, BF16->FP8 E4M3 quantization - Atomic ring buffer head increment - FP8/BF16 split write + inv_scale + position metadata - FP8 round-trip: <3.6% relative error - RoPE half: exact match (no quantization) All tests pass on B200: - Schema correctness for CSA/HCA/SWA - Allocator acquire/release/OOM - Pool shapes match architecture spec - Manager lifecycle (admit/release/recycle/exhaustion) - Zero-alloc acquire() (cudagraph safe) - append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation --- dsv4/cache/allocator.py | 56 +++++++ dsv4/cache/handle.py | 142 +++++++++++++++++ dsv4/cache/manager.py | 200 ++++++++++++++++++++++++ dsv4/cache/paged_cache.py | 93 +++++++++++- dsv4/cache/schema.py | 129 ++++++++++++++++ dsv4/cache/state_cache.py | 98 +++++++++++- dsv4/kernels/cache/__init__.py | 1 + dsv4/kernels/cache/append_swa.py | 51 +++++++ dsv4/kernels/cuda/append_swa.cu | 165 ++++++++++++++++++++ tests/unit/test_cache.py | 252 +++++++++++++++++++++++++++++++ 10 files changed, 1183 insertions(+), 4 deletions(-) create mode 100644 dsv4/cache/allocator.py create mode 100644 dsv4/cache/handle.py create mode 100644 dsv4/cache/manager.py create mode 100644 dsv4/cache/schema.py create mode 100644 dsv4/kernels/cache/__init__.py create mode 100644 dsv4/kernels/cache/append_swa.py create mode 100644 dsv4/kernels/cuda/append_swa.cu create mode 100644 tests/unit/test_cache.py diff --git a/dsv4/cache/allocator.py b/dsv4/cache/allocator.py new file mode 100644 index 00000000..acdeb61b --- /dev/null +++ b/dsv4/cache/allocator.py @@ -0,0 +1,56 @@ +"""Fixed-size block allocator for the classical paged KV cache. + +One BlockAllocator per layer per "pool kind" (classical / indexer). +Total blocks are sized at engine startup. Blocks are recycled on +request completion. + +Cudagraph-safety: allocation can't happen inside a captured graph +(allocation rate is per-request not per-token). The contract is: + - acquire() called between graph captures. + - release() called between graph captures. + - read access (via block table) happens INSIDE captured graphs. +""" +from __future__ import annotations +import torch + + +class BlockAllocator: + def __init__( + self, + num_total_blocks: int, + device: str = "cuda", + ): + self.num_total_blocks = num_total_blocks + self.device = device + + # Free-list as a GPU stack: ids[0..top-1] holds free block IDs. + # `top` lives in pinned host memory so we can read it without a + # device sync (it's modified only between graph captures). + self.free_ids = torch.arange( + num_total_blocks, dtype=torch.int32, device=device, + ) + self.top_cpu = torch.tensor([num_total_blocks], dtype=torch.int32, pin_memory=True) + + @property + def num_free(self) -> int: + return int(self.top_cpu[0]) + + def acquire(self, n: int) -> torch.Tensor: + """Return a tensor of `n` block IDs. Called between captures.""" + top = int(self.top_cpu[0]) + if n > top: + raise RuntimeError( + f"KV cache OOM: requested {n} blocks, {top} available " + f"(of {self.num_total_blocks} total)" + ) + new_top = top - n + ids = self.free_ids[new_top:top].clone() # snapshot + self.top_cpu[0] = new_top + return ids + + def release(self, ids: torch.Tensor) -> None: + """Return blocks to the free list. Called between captures.""" + n = ids.numel() + top = int(self.top_cpu[0]) + self.free_ids[top:top + n] = ids.to(device=self.device) + self.top_cpu[0] = top + n diff --git a/dsv4/cache/handle.py b/dsv4/cache/handle.py new file mode 100644 index 00000000..a638cfc9 --- /dev/null +++ b/dsv4/cache/handle.py @@ -0,0 +1,142 @@ +"""LayerCacheHandle — typed per-call view onto one layer's cache. + +Constructed by KVCacheManager.acquire() once per layer per forward. +Holds tensor references and integer indices; no allocation. Methods +expose the operations AttentionSubBlock needs without exposing the +underlying storage layout. +""" +from __future__ import annotations +from dataclasses import dataclass +from typing import Optional, TYPE_CHECKING +import torch + +if TYPE_CHECKING: + from dsv4.cache.paged_cache import PagedKVPool + from dsv4.cache.state_cache import StateCachePool + + +@dataclass +class LayerCacheHandle: + """Read/write interface for one layer's cache. + + The fields are the resolved indices and tensor refs for THIS call's + batch of requests. AttentionSubBlock never sees raw pool tensors. + """ + # Pool references (shared across handles — never mutated). + paged: Optional["PagedKVPool"] + state: "StateCachePool" + + # Per-call indices. + request_slots: torch.Tensor # [batch] int32 — state cache slot per request + positions: torch.Tensor # [tokens] int32 — absolute position per token + request_ids: torch.Tensor # [tokens] int32 — which request each token belongs to + + # Block table for the classical pool (None for SWA-only layers). + # Shape: [batch, max_logical_blocks] int32. -1 padding for unused entries. + block_table: Optional[torch.Tensor] + # Number of valid blocks per request (excludes padding). + block_lens: Optional[torch.Tensor] + + # ------------------------------------------------------------------ + # Methods called by AttentionSubBlock + # ------------------------------------------------------------------ + def write_swa( + self, + raw_kv: torch.Tensor, # (T, head_dim) BF16 + ) -> None: + """Write raw KV into the SWA ring buffer AND tail compression buffer. + + Both regions get the same tokens — SWA consumes the last n_win, + the tail accumulates until it can flush. + """ + from dsv4.kernels.cache.append_swa import append_swa_kernel + append_swa_kernel( + raw_kv=raw_kv, + request_slots=self.request_slots, + positions=self.positions, + swa_fp8=self.state.swa_fp8, + swa_rope=self.state.swa_rope, + swa_inv=self.state.swa_inv, + swa_pos=self.state.swa_pos, + swa_head=self.state.swa_head, + rope_dim=self.state.schema.rope_dim, + ) + + def flush_compression( + self, + compressed: torch.Tensor, # (T_flush, head_dim) BF16 — newly produced + indexer_keys: Optional[torch.Tensor] = None, + ) -> None: + """Promote pending tail tokens into the classical pool. + + Called by the compressor when the tail buffer has enough tokens. + Allocates a new block if the latest block is full. + + Block allocation requires going outside the captured graph — in + a fully-captured decode this is rare (once per m or m' tokens), + so we make it explicit. The manager has the contract. + """ + raise NotImplementedError("see kernels/cache/flush_compression.py") + + def read_swa_view(self) -> "SWAView": + """Return a typed view of the SWA window for this batch.""" + return SWAView( + fp8=self.state.swa_fp8, + rope=self.state.swa_rope, + inv_scale=self.state.swa_inv, + positions=self.state.swa_pos, + head=self.state.swa_head, + slots=self.request_slots, + ) + + def read_classical_view(self) -> "ClassicalView": + """Return a typed view of compressed entries for this batch.""" + assert self.paged is not None, "SWA-only layers have no classical cache" + return ClassicalView( + entries_fp8=self.paged.entries_fp8, + entries_rope=self.paged.entries_rope, + inv_scale=self.paged.inv_scale, + block_table=self.block_table, + block_lens=self.block_lens, + ) + + def read_indexer_view(self) -> "IndexerView": + """CSA-only. Returns FP4 indexer keys with their scales.""" + assert self.paged is not None and self.paged.indexer_keys_fp4 is not None + return IndexerView( + keys_fp4=self.paged.indexer_keys_fp4, + scale=self.paged.indexer_scale, + global_scale=self.paged.indexer_global_scale, + block_table=self.block_table, + block_lens=self.block_lens, + ) + + +# Typed views — simple dataclasses, no logic. The FMHA / indexer / SWA +# kernels accept these to keep their signatures clean. +@dataclass +class SWAView: + fp8: torch.Tensor + rope: torch.Tensor + inv_scale: torch.Tensor + positions: torch.Tensor + head: torch.Tensor + slots: torch.Tensor + + +@dataclass +class ClassicalView: + entries_fp8: torch.Tensor + entries_rope: torch.Tensor + inv_scale: torch.Tensor + block_table: torch.Tensor + block_lens: torch.Tensor + + +@dataclass +class IndexerView: + keys_fp4: torch.Tensor + scale: torch.Tensor + global_scale: torch.Tensor + block_table: torch.Tensor + block_lens: torch.Tensor diff --git a/dsv4/cache/manager.py b/dsv4/cache/manager.py new file mode 100644 index 00000000..8af0093d --- /dev/null +++ b/dsv4/cache/manager.py @@ -0,0 +1,200 @@ +"""KVCacheManager — owns all KV cache state for one model instance. + +Responsibilities: + - Build per-layer pools and allocators at startup. + - Hand out state-cache slots when requests are admitted. + - Hand out classical blocks when layers need to flush compression. + - Compose LayerCacheHandle for each layer per forward call. + - Reclaim slots and blocks on request completion. + +Not on the manager: + - On-disk prefix storage. (Paper §3.5.2 — deferred entirely.) + - Eviction policies. (Single-instance; requests run to completion.) + - Cross-instance coordination. +""" +from __future__ import annotations +from typing import List, Optional, Dict +import torch + +from dsv4.model.config import DSV4Config +from dsv4.model.layer_schedule import LayerSpec, AttentionType +from dsv4.cache.schema import LayerCacheSchema, build_schema, compute_block_budget +from dsv4.cache.allocator import BlockAllocator +from dsv4.cache.paged_cache import PagedKVPool +from dsv4.cache.state_cache import StateCachePool +from dsv4.cache.handle import LayerCacheHandle + + +class KVCacheManager: + def __init__( + self, + config: DSV4Config, + schedule: List[LayerSpec], + max_concurrent_requests: int, + max_context_tokens: int = 1_000_000, + # Per-layer-type block budget. If None, computed from + # max_context_tokens and max_concurrent_requests. + num_blocks_per_csa_layer: Optional[int] = None, + num_blocks_per_hca_layer: Optional[int] = None, + device: str = "cuda", + ): + self.config = config + self.schedule = schedule + self.max_concurrent_requests = max_concurrent_requests + self.device = device + + # ---- Per-layer schemas ---- + self.schemas: Dict[int, LayerCacheSchema] = { + spec.layer_idx: build_schema(config, spec) for spec in schedule + } + + # ---- Compute block budgets if not provided ---- + if num_blocks_per_csa_layer is None or num_blocks_per_hca_layer is None: + budget = compute_block_budget(config, schedule, max_context_tokens, + max_concurrent_requests) + num_blocks_per_csa_layer = num_blocks_per_csa_layer or budget.get("csa", 0) + num_blocks_per_hca_layer = num_blocks_per_hca_layer or budget.get("hca", 0) + + # ---- Per-layer pools ---- + # State cache exists for every layer. + self.state_pools: Dict[int, StateCachePool] = { + i: StateCachePool(schema, max_concurrent_requests, device) + for i, schema in self.schemas.items() + } + + # Classical paged pool only for compressed layers. + self.paged_pools: Dict[int, Optional[PagedKVPool]] = {} + self.allocators: Dict[int, Optional[BlockAllocator]] = {} + for i, schema in self.schemas.items(): + if schema.entries_per_block == 0: + self.paged_pools[i] = None + self.allocators[i] = None + else: + nb = (num_blocks_per_csa_layer + if schema.attn_type == AttentionType.CSA + else num_blocks_per_hca_layer) + self.paged_pools[i] = PagedKVPool(schema, nb, device) + self.allocators[i] = BlockAllocator(nb, device) + + # ---- Request state ---- + # Slot index per request, into state cache pools (same index in + # every layer). -1 = slot free. + self.request_slot_map: torch.Tensor = torch.full( + (max_concurrent_requests,), -1, dtype=torch.int32, device=device, + ) + + # Block table per request per layer: + # block_tables[layer_idx][request_slot, logical_block_idx] + # -> physical_block_idx + max_blocks = max_context_tokens // 128 # BLOCK_SIZE_ORIGINAL_TOKENS + self.max_blocks_per_request = max_blocks + self.block_tables: Dict[int, torch.Tensor] = {} + self.block_lens: Dict[int, torch.Tensor] = {} + for i, schema in self.schemas.items(): + if schema.entries_per_block > 0: + self.block_tables[i] = torch.full( + (max_concurrent_requests, max_blocks), -1, + dtype=torch.int32, device=device, + ) + self.block_lens[i] = torch.zeros( + (max_concurrent_requests,), dtype=torch.int32, device=device, + ) + + # ------------------------------------------------------------------ + # Request lifecycle (called between captured graphs) + # ------------------------------------------------------------------ + def admit_request(self) -> int: + """Allocate a state cache slot. Returns the slot index.""" + free = (self.request_slot_map == -1).nonzero(as_tuple=False) + if free.numel() == 0: + raise RuntimeError("max concurrent requests exceeded") + slot = int(free[0]) + self.request_slot_map[slot] = slot + return slot + + def release_request(self, slot: int) -> None: + """Return state cache slot and all associated blocks to free lists.""" + for layer_idx, alloc in self.allocators.items(): + if alloc is None: + continue + table = self.block_tables[layer_idx] + lens = self.block_lens[layer_idx] + valid = int(lens[slot]) + if valid > 0: + alloc.release(table[slot, :valid].clone()) + lens[slot] = 0 + table[slot].fill_(-1) + # Reset state cache slot. + for state in self.state_pools.values(): + state.reset_slot(slot) + self.request_slot_map[slot] = -1 + + # ------------------------------------------------------------------ + # Block allocation for compression flush (called between captures) + # ------------------------------------------------------------------ + def allocate_block(self, layer_idx: int, request_slot: int) -> int: + """Allocate one new classical block for a request. Returns block ID.""" + alloc = self.allocators[layer_idx] + assert alloc is not None, f"layer {layer_idx} has no classical pool" + block_id = alloc.acquire(1) + bid = int(block_id[0]) + # Append to the request's block table. + table = self.block_tables[layer_idx] + lens = self.block_lens[layer_idx] + pos = int(lens[request_slot]) + assert pos < self.max_blocks_per_request, "block table overflow" + table[request_slot, pos] = bid + lens[request_slot] = pos + 1 + return bid + + # ------------------------------------------------------------------ + # Per-forward handle construction (called INSIDE captured graph) + # ------------------------------------------------------------------ + def acquire( + self, + layer_idx: int, + request_slots: torch.Tensor, # [batch] int32 + positions: torch.Tensor, # [tokens] int32 + request_ids: torch.Tensor, # [tokens] int32 + ) -> LayerCacheHandle: + """Build the LayerCacheHandle for one layer's forward. + + No allocation happens here — critical for cudagraph safety. + """ + paged = self.paged_pools[layer_idx] + state = self.state_pools[layer_idx] + + if paged is not None: + # Pass the full tensors — no indexing, no allocation. + # The attention kernel indexes by request_slots internally. + block_table = self.block_tables[layer_idx] + block_lens = self.block_lens[layer_idx] + else: + block_table = None + block_lens = None + + return LayerCacheHandle( + paged=paged, + state=state, + request_slots=request_slots, + positions=positions, + request_ids=request_ids, + block_table=block_table, + block_lens=block_lens, + ) + + # ------------------------------------------------------------------ + # Diagnostics + # ------------------------------------------------------------------ + def memory_bytes(self) -> int: + """Total GPU memory used by all pools.""" + total = 0 + for pool in self.state_pools.values(): + total += pool.memory_bytes() + for pool in self.paged_pools.values(): + if pool is not None: + total += pool.memory_bytes() + for i, table in self.block_tables.items(): + total += table.numel() * table.element_size() + total += self.block_lens[i].numel() * self.block_lens[i].element_size() + return total diff --git a/dsv4/cache/paged_cache.py b/dsv4/cache/paged_cache.py index fce4b419..18cf9f7a 100644 --- a/dsv4/cache/paged_cache.py +++ b/dsv4/cache/paged_cache.py @@ -1,2 +1,91 @@ -"""Paged KV cache.""" -# TODO: Phase 3 +"""Storage for one layer's classical paged KV cache. + +Layout per block: + entries: [num_blocks, entries_per_block, head_dim - rope_dim] FP8 (uint8 view) + entries_r: [num_blocks, entries_per_block, rope_dim] BF16 + inv_scale: [num_blocks, entries_per_block] FP32 + +The FP8/BF16 split mirrors paper §2.3.4 ("BF16 for RoPE dims, FP8 for +the rest"). The kernel reads both halves and concatenates in registers. + +For CSA layers, a parallel pool stores indexer keys at the same block +granularity — same block ID maps to a block in both pools. +""" +from __future__ import annotations +from typing import Optional +import torch + +from dsv4.cache.schema import LayerCacheSchema + + +class PagedKVPool: + """Per-layer classical paged KV storage. + + Indexed by [physical_block_id, slot_in_block, ...]. + Both compressed entries and indexer keys (if applicable) are + indexed by the SAME physical_block_id so a CSA layer's two pools + share the block table. + """ + + def __init__( + self, + schema: LayerCacheSchema, + num_blocks: int, + device: str = "cuda", + ): + self.schema = schema + self.num_blocks = num_blocks + self.device = device + + nb = num_blocks + epb = schema.entries_per_block + hd = schema.entry_head_dim + rd = schema.rope_dim + fp8_dim = hd - rd + + # ---- Compressed entries ---- + # FP8 stored as uint8 (we view as float8_e4m3fn at read time). + self.entries_fp8 = torch.zeros( + (nb, epb, fp8_dim), dtype=torch.uint8, device=device, + ) + # BF16 RoPE'd half — no quantization. + self.entries_rope = torch.zeros( + (nb, epb, rd), dtype=torch.bfloat16, device=device, + ) + # Per-entry inverse scale (for FP8 dequant in attention kernel). + self.inv_scale = torch.ones( + (nb, epb), dtype=torch.float32, device=device, + ) + + # ---- Indexer keys (CSA only) ---- + if schema.indexer_entries_per_block > 0: + i_epb = schema.indexer_entries_per_block + i_hd = schema.indexer_head_dim + # Indexer QK is FP4 per paper §2.3.4 — but we store the keys + # post-quant. uint8 = 2 FP4 packed per byte. + self.indexer_keys_fp4 = torch.zeros( + (nb, i_epb, i_hd // 2), dtype=torch.uint8, device=device, + ) + # Per-block-vector scale for the FP4 (one E4M3 scalar per + # 16-element group, per the NVFP4 quantization scheme). + self.indexer_scale = torch.ones( + (nb, i_epb, i_hd // 16), + dtype=torch.float8_e4m3fn, device=device, + ) + self.indexer_global_scale = torch.ones( + (nb,), dtype=torch.float32, device=device, + ) + else: + self.indexer_keys_fp4 = None + self.indexer_scale = None + self.indexer_global_scale = None + + def memory_bytes(self) -> int: + """Total GPU memory used by this pool.""" + total = 0 + for name in ("entries_fp8", "entries_rope", "inv_scale", + "indexer_keys_fp4", "indexer_scale", "indexer_global_scale"): + t = getattr(self, name) + if t is not None: + total += t.numel() * t.element_size() + return total diff --git a/dsv4/cache/schema.py b/dsv4/cache/schema.py new file mode 100644 index 00000000..78b58e10 --- /dev/null +++ b/dsv4/cache/schema.py @@ -0,0 +1,129 @@ +"""Per-layer KV cache shape. + +Computed once per layer at engine startup from the LayerSpec. The +schema is what tells the allocator how big each pool slot is and what +sub-regions exist (compressed entries / indexer keys / SWA window / +uncompressed tail). +""" +from __future__ import annotations +from dataclasses import dataclass +from typing import Optional + +from dsv4.model.config import DSV4Config +from dsv4.model.layer_schedule import LayerSpec, AttentionType + + +# Block size is invariant for DSV4 — derived from compression ratios. +# lcm(m, m') = lcm(4, 128) = 128 original tokens per block. +# Holds 128/4 = 32 CSA entries OR 128/128 = 1 HCA entry per block. +BLOCK_SIZE_ORIGINAL_TOKENS = 128 + + +@dataclass(frozen=True) +class LayerCacheSchema: + """Cache layout for one transformer layer. + + Fields with `_per_block` are the dimensions of one block in the + classical paged pool. `_per_state_slot` are dimensions of one + request's slot in the state cache. + + All sizes are in number of entries — bytes come from the dtypes. + """ + layer_idx: int + attn_type: AttentionType + + # ---- Classical paged cache (compressed entries) ---- + # Number of compressed entries in one block of BLOCK_SIZE_ORIGINAL_TOKENS + # original tokens. For HCA m'=128 this is 1; for CSA m=4 this is 32. + # SWA-only layers have no classical pool — entries_per_block = 0. + entries_per_block: int + # Width of one entry (head_dim). + entry_head_dim: int + # RoPE-applied dimensions (BF16). Others FP8. + rope_dim: int + + # ---- Indexer pool (CSA only) ---- + # Compressed indexer keys, one per compressed entry. + indexer_entries_per_block: int # 32 for CSA, 0 for HCA/SWA + indexer_head_dim: int # 128 for CSA, 0 for others + + # ---- State cache (SWA window + uncompressed tail) ---- + swa_window_size: int # 128 for all layer types + # Uncompressed tail buffer — needed only for compressed layers. + # CSA: up to m-1 = 3 pending tokens before flushing compression. + # HCA: up to m'-1 = 127 pending tokens. + # SWA-only: no tail (no compression branch). + tail_buffer_size: int + + # Per-token inverse scale storage (for FP8 dequant). One FP32 scalar + # per stored entry/window-slot. + needs_inv_scale: bool = True + + +def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema: + """Derive cache schema for a single layer from architectural config.""" + if spec.attn == AttentionType.CSA: + return LayerCacheSchema( + layer_idx=spec.layer_idx, + attn_type=AttentionType.CSA, + entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // config.csa_compression_ratio, + entry_head_dim=config.head_dim, + rope_dim=config.rope_dim, + indexer_entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // config.csa_compression_ratio, + indexer_head_dim=config.indexer_head_dim, + swa_window_size=config.sliding_window, + tail_buffer_size=config.csa_compression_ratio - 1, + ) + elif spec.attn == AttentionType.HCA: + return LayerCacheSchema( + layer_idx=spec.layer_idx, + attn_type=AttentionType.HCA, + entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // config.hca_compression_ratio, + entry_head_dim=config.head_dim, + rope_dim=config.rope_dim, + indexer_entries_per_block=0, + indexer_head_dim=0, + swa_window_size=config.sliding_window, + tail_buffer_size=config.hca_compression_ratio - 1, + ) + else: # SWA-only + return LayerCacheSchema( + layer_idx=spec.layer_idx, + attn_type=AttentionType.SWA, + entries_per_block=0, + entry_head_dim=config.head_dim, + rope_dim=config.rope_dim, + indexer_entries_per_block=0, + indexer_head_dim=0, + swa_window_size=config.sliding_window, + tail_buffer_size=0, + ) + + +def compute_block_budget( + config: DSV4Config, + schedule: list[LayerSpec], + max_context_tokens: int, + max_concurrent_requests: int, +) -> dict[str, int]: + """Compute per-layer-type block counts for the allocator. + + Returns {layer_type: num_blocks} where layer_type is 'csa' or 'hca'. + SWA-only layers need no classical blocks. + + Block budget = max_concurrent_requests * (max_context / BLOCK_SIZE). + Add 10% headroom for fragmentation. + """ + blocks_per_request = max_context_tokens // BLOCK_SIZE_ORIGINAL_TOKENS + headroom = 1.10 + result = {} + for spec in schedule: + if spec.attn == AttentionType.CSA: + key = "csa" + elif spec.attn == AttentionType.HCA: + key = "hca" + else: + continue + total = int(max_concurrent_requests * blocks_per_request * headroom) + result[key] = max(result.get(key, 0), total) + return result diff --git a/dsv4/cache/state_cache.py b/dsv4/cache/state_cache.py index e19485ba..454d3212 100644 --- a/dsv4/cache/state_cache.py +++ b/dsv4/cache/state_cache.py @@ -1,2 +1,96 @@ -"""State cache for KV.""" -# TODO: Phase 3 +"""State cache: SWA window + uncompressed tail buffer. + +One slot per active request. Slot index is fixed for a request's +lifetime — the manager hands out slot indices at request admission +and reclaims them at completion. + +Per paper §3.5.1: SWA and tail tokens are state-space-like — they +depend only on the current position, not on a paged history. No +block table; a flat [max_requests, ...] tensor. +""" +from __future__ import annotations +import torch + +from dsv4.cache.schema import LayerCacheSchema, AttentionType + + +class StateCachePool: + """Per-layer state cache (SWA window + uncompressed tail). + + Storage layout per slot: + swa_fp8: [n_win, head_dim - rope_dim] FP8 raw KV in window + swa_rope: [n_win, rope_dim] BF16 RoPE'd half + swa_inv: [n_win] FP32 per-token inv scale + swa_pos: [n_win] int32 — absolute position + of each window slot (-1 if invalid) + + tail_ka: [tail_size, head_dim] BF16 raw — pending tokens + not yet compressed + tail_za: [tail_size, head_dim] BF16 — compression weights + (Z stream for CSA, single Z for HCA) + tail_kb: [tail_size, head_dim] BF16 — second stream (CSA only) + tail_zb: [tail_size, head_dim] BF16 — second Z stream (CSA only) + tail_len: scalar int32 — how many tail entries are valid + """ + + def __init__( + self, + schema: LayerCacheSchema, + max_requests: int, + device: str = "cuda", + ): + self.schema = schema + self.max_requests = max_requests + self.device = device + + mr = max_requests + nw = schema.swa_window_size + hd = schema.entry_head_dim + rd = schema.rope_dim + fp8 = hd - rd + + # SWA window — circular within each slot. Layer's attention + # kernel uses swa_pos to mask invalid entries. + self.swa_fp8 = torch.zeros((mr, nw, fp8), dtype=torch.uint8, device=device) + self.swa_rope = torch.zeros((mr, nw, rd), dtype=torch.bfloat16, device=device) + self.swa_inv = torch.ones((mr, nw), dtype=torch.float32, device=device) + self.swa_pos = torch.full((mr, nw), -1, dtype=torch.int32, device=device) + # Next write position within each slot's ring buffer. + self.swa_head = torch.zeros((mr,), dtype=torch.int32, device=device) + + # Tail buffer — only non-empty for compressed layers. + tail = schema.tail_buffer_size + if tail > 0: + # For CSA we need two streams (Ca/Cb, Za/Zb) since the + # compressor uses overlapping pairs. HCA only needs one + # stream. Store both; HCA leaves the b-channel zero. + self.tail_ka = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device) + self.tail_za = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device) + if schema.attn_type == AttentionType.CSA: + self.tail_kb = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device) + self.tail_zb = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device) + else: + self.tail_kb = None + self.tail_zb = None + self.tail_len = torch.zeros((mr,), dtype=torch.int32, device=device) + else: + self.tail_ka = self.tail_kb = None + self.tail_za = self.tail_zb = None + self.tail_len = None + + def reset_slot(self, slot: int) -> None: + """Clear a request's state after completion.""" + self.swa_pos[slot].fill_(-1) + self.swa_head[slot] = 0 + if self.tail_len is not None: + self.tail_len[slot] = 0 + + def memory_bytes(self) -> int: + """Total GPU memory used by this pool.""" + total = 0 + for name in ("swa_fp8", "swa_rope", "swa_inv", "swa_pos", "swa_head", + "tail_ka", "tail_za", "tail_kb", "tail_zb", "tail_len"): + t = getattr(self, name) + if t is not None: + total += t.numel() * t.element_size() + return total diff --git a/dsv4/kernels/cache/__init__.py b/dsv4/kernels/cache/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/dsv4/kernels/cache/__init__.py @@ -0,0 +1 @@ + diff --git a/dsv4/kernels/cache/append_swa.py b/dsv4/kernels/cache/append_swa.py new file mode 100644 index 00000000..b1c0c4d8 --- /dev/null +++ b/dsv4/kernels/cache/append_swa.py @@ -0,0 +1,51 @@ +"""Python wrapper for the append_swa CUDA kernel. + +Writes raw BF16 KV into the FP8/BF16 split state cache layout. +Quantizes the non-RoPE half BF16 -> FP8 (E4M3 amax-based scaling), +writes the RoPE half as-is, computes per-token inverse scale, and +updates the ring buffer head + position field. + +One block per token. Threads cooperatively: + 1. Compute amax over fp8-dim elements (warp reduce). + 2. Quantize BF16 -> FP8 with per-token scale. + 3. Write FP8 entries + BF16 RoPE entries + inv_scale + position. + 4. Atomic increment ring buffer head. +""" +import os +import torch +from torch.utils.cpp_extension import load + +_kernel_module = None + + +def _get_kernel_module(): + global _kernel_module + if _kernel_module is not None: + return _kernel_module + kernel_dir = os.path.join(os.path.dirname(__file__), "..", "cuda") + _kernel_module = load( + name="append_swa", + sources=[os.path.join(kernel_dir, "append_swa.cu")], + extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"], + verbose=False, + ) + return _kernel_module + + +def append_swa_kernel( + raw_kv: torch.Tensor, # (T, head_dim) BF16 + request_slots: torch.Tensor, # (T,) int32 + positions: torch.Tensor, # (T,) int32 + swa_fp8: torch.Tensor, # (max_req, n_win, fp8_dim) uint8 + swa_rope: torch.Tensor, # (max_req, n_win, rope_dim) BF16 + swa_inv: torch.Tensor, # (max_req, n_win) FP32 + swa_pos: torch.Tensor, # (max_req, n_win) int32 + swa_head: torch.Tensor, # (max_req,) int32 + rope_dim: int, +): + mod = _get_kernel_module() + mod.append_swa( + raw_kv, request_slots, positions, + swa_fp8, swa_rope, swa_inv, swa_pos, swa_head, + rope_dim, + ) diff --git a/dsv4/kernels/cuda/append_swa.cu b/dsv4/kernels/cuda/append_swa.cu new file mode 100644 index 00000000..dd69ba6e --- /dev/null +++ b/dsv4/kernels/cuda/append_swa.cu @@ -0,0 +1,165 @@ +// append_swa.cu — write raw BF16 KV into the SWA ring buffer. +// +// One block per token. Threads cooperatively: +// 1. Compute amax over fp8-dim elements (warp reduce). +// 2. Quantize BF16 -> FP8 E4M3 with per-token scale. +// 3. Write FP8 entries + BF16 RoPE entries + inv_scale + position. +// 4. Atomic increment ring buffer head. +// +// Paper §2.3.4: BF16 for RoPE'd dims, FP8 for the rest. +// Per-token inverse scale stored for dequant in the attention kernel. + +#include +#include +#include +#include +#include + +#include + +// Warp-level amax reduction +__device__ __forceinline__ float warp_reduce_amax(float val) { + for (int offset = 16; offset > 0; offset >>= 1) { + float other = __shfl_down_sync(0xffffffff, val, offset); + val = fmaxf(val, fabsf(other)); + } + return val; +} + +// Warp-level sum for counting valid entries +__device__ __forceinline__ float warp_reduce_sum(float val) { + for (int offset = 16; offset > 0; offset >>= 1) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + return val; +} + +__global__ void append_swa_kernel( + const __nv_bfloat16* __restrict__ raw_kv, // [T, head_dim] + const int32_t* __restrict__ request_slots, // [T] -> slot in state pool + const int32_t* __restrict__ positions, // [T] -> absolute position + // State cache pool — written in place. + uint8_t* __restrict__ swa_fp8, // [max_req, n_win, fp8_dim] + __nv_bfloat16* __restrict__ swa_rope, // [max_req, n_win, rope_dim] + float* __restrict__ swa_inv, // [max_req, n_win] + int32_t* __restrict__ swa_pos, // [max_req, n_win] + int32_t* __restrict__ swa_head, // [max_req] + int T, int n_win, int head_dim, int rope_dim +) { + int t = blockIdx.x; + if (t >= T) return; + + int lane = threadIdx.x; + int warp_size = blockDim.x; // expect 128 threads per block + + int slot = request_slots[t]; + int pos = positions[t]; + int fp8_dim = head_dim - rope_dim; + + // ---- Step 1: Compute amax over fp8_dim elements ---- + // Each thread processes strided elements of the fp8 half. + float local_amax = 0.0f; + for (int i = lane; i < fp8_dim; i += warp_size) { + float val = __bfloat162float(raw_kv[t * head_dim + i]); + local_amax = fmaxf(local_amax, fabsf(val)); + } + + // Warp-level amax reduction (works for warp_size <= 32). + // For 128 threads, we need to reduce across 4 warps. + float block_amax = 0.0f; + // Intra-warp reduce + float warp_amax = warp_reduce_amax(local_amax); + // Lane 0 of each warp writes to shared memory + __shared__ float smem_amax[4]; // max 4 warps for 128 threads + if (lane % 32 == 0) { + smem_amax[lane / 32] = warp_amax; + } + __syncthreads(); + if (lane < 32) { + float v = (lane < (warp_size + 31) / 32) ? smem_amax[lane] : 0.0f; + block_amax = warp_reduce_amax(v); + } + __syncthreads(); + // Broadcast block_amax to all threads + __shared__ float s_inv_scale; + if (lane == 0) { + float scale = block_amax / 448.0f; // FP8 E4M3 max = 448 + if (scale < 1e-12f) scale = 1e-12f; // avoid div-by-zero + s_inv_scale = scale; + } + __syncthreads(); + float inv_scale_val = s_inv_scale; + + // ---- Step 2: Atomic increment ring buffer head ---- + // Only one thread per block does the atomic + __shared__ int slot_in_window; + if (lane == 0) { + slot_in_window = atomicAdd(&swa_head[slot], 1) % n_win; + } + __syncthreads(); + + // ---- Step 3: Write FP8 entries ---- + for (int i = lane; i < fp8_dim; i += warp_size) { + float val = __bfloat162float(raw_kv[t * head_dim + i]); + float quantized = val / inv_scale_val; + // Clamp to FP8 E4M3 range [-448, 448] + quantized = fmaxf(-448.0f, fminf(448.0f, quantized)); + // Convert to FP8 E4M3 + __nv_fp8_e4m3 fp8_val; + fp8_val.__x = __nv_fp8_e4m3(quantized).__x; + swa_fp8[slot * n_win * fp8_dim + slot_in_window * fp8_dim + i] = fp8_val.__x; + } + + // ---- Step 4: Write BF16 RoPE entries ---- + for (int i = lane; i < rope_dim; i += warp_size) { + __nv_bfloat16 val = raw_kv[t * head_dim + fp8_dim + i]; + swa_rope[slot * n_win * rope_dim + slot_in_window * rope_dim + i] = val; + } + + // ---- Step 5: Write metadata (single thread) ---- + if (lane == 0) { + swa_inv[slot * n_win + slot_in_window] = inv_scale_val; + swa_pos[slot * n_win + slot_in_window] = pos; + } +} + + +std::tuple +append_swa_cuda( + torch::Tensor raw_kv, // [T, head_dim] BF16 + torch::Tensor request_slots, // [T] int32 + torch::Tensor positions, // [T] int32 + torch::Tensor swa_fp8, // [max_req, n_win, fp8_dim] uint8 + torch::Tensor swa_rope, // [max_req, n_win, rope_dim] BF16 + torch::Tensor swa_inv, // [max_req, n_win] FP32 + torch::Tensor swa_pos, // [max_req, n_win] int32 + torch::Tensor swa_head, // [max_req] int32 + int64_t rope_dim +) { + int T = raw_kv.size(0); + int head_dim = raw_kv.size(1); + int n_win = swa_fp8.size(1); + + int threads = 128; + int blocks = T; + + append_swa_kernel<<>>( + reinterpret_cast(raw_kv.data_ptr()), + request_slots.data_ptr(), + positions.data_ptr(), + swa_fp8.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(swa_rope.data_ptr()), + swa_inv.data_ptr(), + swa_pos.data_ptr(), + swa_head.data_ptr(), + T, n_win, head_dim, static_cast(rope_dim) + ); + + C10_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(swa_fp8, swa_rope, swa_inv, swa_pos, swa_head); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("append_swa", &append_swa_cuda, "Append SWA kernel"); +} diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py new file mode 100644 index 00000000..56170e47 --- /dev/null +++ b/tests/unit/test_cache.py @@ -0,0 +1,252 @@ +"""Tests for KV cache: schema, allocator, pools, manager lifecycle.""" + +import torch +import pytest + +from dsv4.model.config import DSV4Config +from dsv4.model.layer_schedule import build_schedule, AttentionType, RouterMode, LayerSpec +from dsv4.cache.schema import build_schema, compute_block_budget, BLOCK_SIZE_ORIGINAL_TOKENS +from dsv4.cache.allocator import BlockAllocator +from dsv4.cache.paged_cache import PagedKVPool +from dsv4.cache.state_cache import StateCachePool +from dsv4.cache.manager import KVCacheManager + + +# ---- Schema tests ---- + +def test_csa_schema(): + config = DSV4Config.pro() + spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA, + ffn=__import__('dsv4.model.layer_schedule', fromlist=['FFNType']).FFNType.MOE, + router_mode=RouterMode.HASH) + schema = build_schema(config, spec) + assert schema.entries_per_block == 32 # 128 / 4 + assert schema.indexer_entries_per_block == 32 + assert schema.tail_buffer_size == 3 # m - 1 + assert schema.swa_window_size == 128 + + +def test_hca_schema(): + config = DSV4Config.pro() + spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA, + ffn=__import__('dsv4.model.layer_schedule', fromlist=['FFNType']).FFNType.MOE, + router_mode=RouterMode.DENSE) + schema = build_schema(config, spec) + assert schema.entries_per_block == 1 # 128 / 128 + assert schema.indexer_entries_per_block == 0 + assert schema.tail_buffer_size == 127 # m' - 1 + + +def test_swa_schema(): + config = DSV4Config.flash() + spec = LayerSpec(layer_idx=0, attn=AttentionType.SWA, + ffn=__import__('dsv4.model.layer_schedule', fromlist=['FFNType']).FFNType.MOE, + router_mode=RouterMode.HASH) + schema = build_schema(config, spec) + assert schema.entries_per_block == 0 + assert schema.indexer_entries_per_block == 0 + assert schema.tail_buffer_size == 0 + assert schema.swa_window_size == 128 + + +def test_schema_from_schedule(): + """Every layer in a full schedule produces a valid schema.""" + config = DSV4Config.flash() + schedule = build_schedule(config) + for spec in schedule: + schema = build_schema(config, spec) + assert schema.swa_window_size > 0 + assert schema.entry_head_dim == config.head_dim + if spec.attn == AttentionType.SWA: + assert schema.entries_per_block == 0 + else: + assert schema.entries_per_block > 0 + + +def test_block_budget(): + config = DSV4Config.pro() + schedule = build_schedule(config) + budget = compute_block_budget(config, schedule, 1_000_000, 16) + assert "csa" in budget + assert "hca" in budget + assert budget["csa"] > budget["hca"] # CSA uses more blocks per request + + +# ---- Allocator tests ---- + +def test_acquire_release_roundtrip(): + alloc = BlockAllocator(num_total_blocks=1024) + a = alloc.acquire(10) + assert alloc.num_free == 1014 + b = alloc.acquire(5) + assert alloc.num_free == 1009 + alloc.release(a) + assert alloc.num_free == 1019 + alloc.release(b) + assert alloc.num_free == 1024 + # Re-acquire works after release. + c = alloc.acquire(20) + assert alloc.num_free == 1004 + + +def test_oom_raises(): + alloc = BlockAllocator(num_total_blocks=4) + alloc.acquire(4) + with pytest.raises(RuntimeError, match="OOM"): + alloc.acquire(1) + + +def test_acquire_returns_unique_ids(): + alloc = BlockAllocator(num_total_blocks=100) + a = alloc.acquire(50) + b = alloc.acquire(50) + assert len(torch.intersect1d(a, b)) == 0 + + +# ---- Pool shape tests ---- + +def test_paged_pool_shapes_csa(): + from dsv4.model.layer_schedule import FFNType + config = DSV4Config.pro() + spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA, + ffn=FFNType.MOE, router_mode=RouterMode.HASH) + schema = build_schema(config, spec) + pool = PagedKVPool(schema, num_blocks=16) + assert pool.entries_fp8.shape == (16, 32, config.head_dim - config.rope_dim) + assert pool.entries_rope.shape == (16, 32, config.rope_dim) + assert pool.inv_scale.shape == (16, 32) + assert pool.indexer_keys_fp4 is not None + assert pool.indexer_keys_fp4.shape[1] == 32 + + +def test_paged_pool_shapes_hca(): + from dsv4.model.layer_schedule import FFNType + config = DSV4Config.pro() + spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA, + ffn=FFNType.MOE, router_mode=RouterMode.DENSE) + schema = build_schema(config, spec) + pool = PagedKVPool(schema, num_blocks=256) + assert pool.entries_fp8.shape == (256, 1, config.head_dim - config.rope_dim) + assert pool.entries_rope.shape == (256, 1, config.rope_dim) + assert pool.indexer_keys_fp4 is None + + +def test_state_pool_shapes_csa(): + from dsv4.model.layer_schedule import FFNType + config = DSV4Config.pro() + spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA, + ffn=FFNType.MOE, router_mode=RouterMode.HASH) + schema = build_schema(config, spec) + pool = StateCachePool(schema, max_requests=8) + assert pool.swa_fp8.shape == (8, 128, config.head_dim - config.rope_dim) + assert pool.swa_rope.shape == (8, 128, config.rope_dim) + assert pool.tail_ka is not None + assert pool.tail_kb is not None # CSA has both streams + assert pool.tail_len is not None + + +def test_state_pool_shapes_hca(): + from dsv4.model.layer_schedule import FFNType + config = DSV4Config.pro() + spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA, + ffn=FFNType.MOE, router_mode=RouterMode.DENSE) + schema = build_schema(config, spec) + pool = StateCachePool(schema, max_requests=8) + assert pool.tail_ka is not None + assert pool.tail_kb is None # HCA only one stream + assert pool.tail_za is not None + assert pool.tail_len is not None + + +def test_state_pool_shapes_swa(): + from dsv4.model.layer_schedule import FFNType + config = DSV4Config.flash() + spec = LayerSpec(layer_idx=0, attn=AttentionType.SWA, + ffn=FFNType.MOE, router_mode=RouterMode.HASH) + schema = build_schema(config, spec) + pool = StateCachePool(schema, max_requests=8) + assert pool.swa_fp8.shape == (8, 128, config.head_dim - config.rope_dim) + assert pool.tail_ka is None # No tail for SWA-only + assert pool.tail_len is None + + +# ---- Manager lifecycle tests ---- + +def test_admit_release_recycles_slot(): + config = DSV4Config.flash() + schedule = build_schedule(config) + mgr = KVCacheManager(config, schedule, max_concurrent_requests=4, + num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64) + s1 = mgr.admit_request() + mgr.release_request(s1) + s2 = mgr.admit_request() + assert s1 == s2 # slot was recycled + + +def test_admit_exhaustion(): + config = DSV4Config.flash() + schedule = build_schedule(config) + mgr = KVCacheManager(config, schedule, max_concurrent_requests=2, + num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64) + mgr.admit_request() + mgr.admit_request() + with pytest.raises(RuntimeError, match="concurrent"): + mgr.admit_request() + + +def test_handle_construction_no_alloc(): + """acquire() should not allocate GPU memory — critical for cudagraph.""" + config = DSV4Config.flash() + schedule = build_schedule(config) + mgr = KVCacheManager(config, schedule, max_concurrent_requests=4, + num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64) + slot = mgr.admit_request() + torch.cuda.synchronize() + before = torch.cuda.memory_allocated() + handle = mgr.acquire( + layer_idx=0, + request_slots=torch.tensor([slot], dtype=torch.int32, device="cuda"), + positions=torch.tensor([0], dtype=torch.int32, device="cuda"), + request_ids=torch.tensor([0], dtype=torch.int32, device="cuda"), + ) + torch.cuda.synchronize() + after = torch.cuda.memory_allocated() + # Allow small variance from tensor view creation, but no large alloc + assert after - before < 1024, f"acquire() allocated {after - before} bytes — breaks cudagraph" + assert handle.paged is None # layer 0 is SWA + mgr.release_request(slot) + + +def test_manager_memory_tracking(): + config = DSV4Config.flash() + schedule = build_schedule(config) + mgr = KVCacheManager(config, schedule, max_concurrent_requests=4, + num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64) + total = mgr.memory_bytes() + assert total > 0 + # Rough sanity: should be in the MB range for this config + assert total > 1_000_000 # at least 1 MB + assert total < 10_000_000_000 # less than 10 GB + + +def test_full_flash_stack_construction(): + """Construct manager with all 43 Flash layers — pools for every layer.""" + config = DSV4Config.flash() + schedule = build_schedule(config) + mgr = KVCacheManager(config, schedule, max_concurrent_requests=4, + num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64) + # 2 SWA layers (no paged pool) + 41 compressed layers + assert len(mgr.paged_pools) == 43 + assert mgr.paged_pools[0] is None # SWA + assert mgr.paged_pools[1] is None # SWA + assert mgr.paged_pools[2] is not None # CSA + assert mgr.paged_pools[3] is not None # HCA + + # All state pools present + assert len(mgr.state_pools) == 43 + for i in range(43): + assert mgr.state_pools[i] is not None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])