Complete KV cache substrate for DSV4 inference:

schema.py: Per-layer cache shape derived from LayerSpec.
- CSA: 32 entries/block, 32 indexer entries, tail=3
- HCA: 1 entry/block, no indexer, tail=127
- SWA: no classical pool, no tail
- BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios)
- compute_block_budget() for allocator sizing

allocator.py: Fixed-size block free-list.
- GPU stack with pinned host top pointer
- acquire/release between graph captures only
- Raises OOM on exhaustion

paged_cache.py: Per-layer classical KV storage.
- FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper §2.3.4)
- Per-entry inverse scale for FP8 dequant
- FP4 indexer keys for CSA layers (NVFP4 scheme)
- memory_bytes() tracking

state_cache.py: Per-layer SWA window + tail buffer.
- Ring buffer with position tracking (swa_head, swa_pos)
- CSA: dual streams (ka/za/kb/zb) for overlapping compression
- HCA: single stream (ka/za only)
- SWA: no tail buffer
- reset_slot() for request completion

handle.py: LayerCacheHandle — typed per-call view.
- write_swa(), read_swa_view(), read_classical_view(), read_indexer_view()
- No GPU allocation in acquire() — 0 bytes delta (cudagraph safe)
- SWAView/ClassicalView/IndexerView dataclasses for kernel signatures

manager.py: KVCacheManager — owns everything.
- Per-layer schema, pool, and allocator construction
- admit_request()/release_request() lifecycle
- allocate_block() for compression flush
- acquire() returns LayerCacheHandle (zero-alloc)

append_swa.cu: Native kernel for SWA writes.
- One thread block per token, 128 threads per block
- Warp-level amax reduction, BF16->FP8 E4M3 quantization
- Atomic ring buffer head increment
- FP8/BF16 split write + inv_scale + position metadata
- FP8 round-trip: <3.6% relative error
- RoPE half: exact match (no quantization)

All tests pass on B200:
- Schema correctness for CSA/HCA/SWA
- Allocator acquire/release/OOM
- Pool shapes match architecture spec
- Manager lifecycle (admit/release/recycle/exhaustion)
- Zero-alloc acquire() (cudagraph safe)
- append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
97 lines
4.0 KiB
Python
97 lines
4.0 KiB
Python
"""State cache: SWA window + uncompressed tail buffer.

One slot per active request. A request's slot index is fixed for its
whole lifetime — the manager hands out slot indices at request admission
and reclaims them at completion.

Per paper §3.5.1: SWA and tail tokens are state-space-like — they
depend only on the current position, not on a paged history. No
block table is needed; each buffer is a flat [max_requests, ...] tensor.
"""
|
|
from __future__ import annotations
|
|
import torch
|
|
|
|
from dsv4.cache.schema import LayerCacheSchema, AttentionType
|
|
|
|
|
|
class StateCachePool:
    """Per-layer state cache (SWA window + uncompressed tail).

    Every tensor has a leading ``[max_requests]`` slot dimension; the
    layouts below describe a single slot.

    Storage layout per slot:
        swa_fp8:  [n_win, head_dim - rope_dim] uint8 — FP8 bytes of the
                  non-RoPE part of each window entry
        swa_rope: [n_win, rope_dim] BF16 — RoPE'd half
        swa_inv:  [n_win] FP32 — per-token inverse scale (initialised to 1)
        swa_pos:  [n_win] int32 — absolute position of each window slot
                  (-1 if invalid)
        swa_head: scalar int32 — next write index in the ring buffer

    Tail buffers are allocated only when ``schema.tail_buffer_size > 0``;
    otherwise every ``tail_*`` attribute below is ``None``:
        tail_ka:  [tail_size, head_dim] BF16 — raw pending tokens not yet
                  compressed
        tail_za:  [tail_size, head_dim] BF16 — compression weights
                  (Z stream for CSA, single Z for HCA)
        tail_kb:  [tail_size, head_dim] BF16 — second stream (CSA only;
                  ``None`` for other layer types)
        tail_zb:  [tail_size, head_dim] BF16 — second Z stream (CSA only;
                  ``None`` for other layer types)
        tail_len: scalar int32 — how many tail entries are valid
    """

    # Every tensor attribute owned by the pool; entries may be None when the
    # corresponding buffer is not allocated for this layer type.
    _TENSOR_ATTRS = (
        "swa_fp8", "swa_rope", "swa_inv", "swa_pos", "swa_head",
        "tail_ka", "tail_za", "tail_kb", "tail_zb", "tail_len",
    )

    def __init__(
        self,
        schema: LayerCacheSchema,
        max_requests: int,
        device: str = "cuda",
    ) -> None:
        """Allocate every per-slot buffer up front.

        Args:
            schema: Per-layer cache schema; supplies ``swa_window_size``,
                ``entry_head_dim``, ``rope_dim``, ``tail_buffer_size`` and
                ``attn_type``.
            max_requests: Number of request slots to allocate.
            device: Device on which all buffers are created.
        """
        self.schema = schema
        self.max_requests = max_requests
        self.device = device

        mr = max_requests
        nw = schema.swa_window_size
        hd = schema.entry_head_dim
        rd = schema.rope_dim
        fp8 = hd - rd  # non-RoPE dims, stored as FP8 bytes

        # SWA window — circular within each slot. The layer's attention
        # kernel uses swa_pos to mask invalid entries.
        self.swa_fp8 = torch.zeros((mr, nw, fp8), dtype=torch.uint8, device=device)
        self.swa_rope = torch.zeros((mr, nw, rd), dtype=torch.bfloat16, device=device)
        self.swa_inv = torch.ones((mr, nw), dtype=torch.float32, device=device)
        self.swa_pos = torch.full((mr, nw), -1, dtype=torch.int32, device=device)
        # Next write position within each slot's ring buffer.
        self.swa_head = torch.zeros((mr,), dtype=torch.int32, device=device)

        # Tail buffer — only allocated for compressed layers.
        tail = schema.tail_buffer_size
        if tail > 0:
            # The "a" stream (ka/za) is always needed. CSA's compressor uses
            # overlapping pairs and therefore also needs the "b" stream
            # (kb/zb). HCA uses a single stream, so kb/zb are left as None
            # and cost no memory.
            self.tail_ka = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
            self.tail_za = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
            if schema.attn_type == AttentionType.CSA:
                self.tail_kb = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
                self.tail_zb = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
            else:
                self.tail_kb = None
                self.tail_zb = None
            self.tail_len = torch.zeros((mr,), dtype=torch.int32, device=device)
        else:
            self.tail_ka = self.tail_kb = None
            self.tail_za = self.tail_zb = None
            self.tail_len = None

    def reset_slot(self, slot: int) -> None:
        """Clear a request's state after completion.

        Marks every window entry invalid (``swa_pos = -1``), rewinds the
        ring-buffer head, and empties the tail (``tail_len = 0``). The data
        tensors (``swa_fp8``, ``swa_rope``, ``swa_inv``, ``tail_*``) are
        left as-is; readers are expected to honour ``swa_pos`` / ``tail_len``.

        Args:
            slot: Slot index in ``[0, max_requests)``.

        Raises:
            IndexError: If ``slot`` is out of range. Negative values are
                rejected explicitly because tensor indexing would otherwise
                wrap them and silently wipe another request's slot.
        """
        if not 0 <= slot < self.max_requests:
            raise IndexError(
                f"slot {slot} out of range for StateCachePool with "
                f"{self.max_requests} slots"
            )
        self.swa_pos[slot].fill_(-1)
        self.swa_head[slot] = 0
        if self.tail_len is not None:
            self.tail_len[slot] = 0

    def memory_bytes(self) -> int:
        """Return the total bytes held by this pool's tensors on ``self.device``.

        Computed as ``numel * element_size`` per tensor; buffers that are not
        allocated (``None``) contribute nothing.
        """
        tensors = (getattr(self, name) for name in self._TENSOR_ATTRS)
        return sum(t.numel() * t.element_size() for t in tensors if t is not None)
|