Files
nvfp4-megamoe-kernel/dsv4/cache/manager.py
biondizzle b4d58df620 KV Cache: schema, allocator, pools, manager, append_swa kernel
Complete KV cache substrate for DSV4 inference:

schema.py: Per-layer cache shape derived from LayerSpec.
  - CSA: 32 entries/block, 32 indexer entries, tail=3
  - HCA: 1 entry/block, no indexer, tail=127
  - SWA: no classical pool, no tail
  - BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios)
  - compute_block_budget() for allocator sizing

allocator.py: Fixed-size block free-list.
  - GPU stack with pinned host top pointer
  - acquire/release between graph captures only
  - OOM raises on exhaustion

paged_cache.py: Per-layer classical KV storage.
  - FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper 2.3.4)
  - Per-entry inverse scale for FP8 dequant
  - FP4 indexer keys for CSA layers (NVFP4 scheme)
  - memory_bytes() tracking

state_cache.py: Per-layer SWA window + tail buffer.
  - Ring buffer with position tracking (swa_head, swa_pos)
  - CSA: dual streams (ka/za/kb/zb) for overlapping compression
  - HCA: single stream (ka/za only)
  - SWA: no tail buffer
  - reset_slot() for request completion

handle.py: LayerCacheHandle — typed per-call view.
  - write_swa(), read_swa_view(), read_classical_view(), read_indexer_view()
  - No GPU allocation in acquire() — 0 bytes delta (cudagraph safe)
  - SWAView/ClassicalView/IndexerView dataclasses for kernel signatures

manager.py: KVCacheManager — owns everything.
  - Per-layer schema, pool, and allocator construction
  - admit_request()/release_request() lifecycle
  - allocate_block() for compression flush
  - acquire() returns LayerCacheHandle (zero-alloc)

append_swa.cu: Native kernel for SWA writes.
  - One block per token, 128 threads per block
  - Warp-level amax reduction, BF16->FP8 E4M3 quantization
  - Atomic ring buffer head increment
  - FP8/BF16 split write + inv_scale + position metadata
  - FP8 round-trip: <3.6% relative error
  - RoPE half: exact match (no quantization)

All tests pass on B200:
  - Schema correctness for CSA/HCA/SWA
  - Allocator acquire/release/OOM
  - Pool shapes match architecture spec
  - Manager lifecycle (admit/release/recycle/exhaustion)
  - Zero-alloc acquire() (cudagraph safe)
  - append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
2026-05-22 00:08:38 +00:00

201 lines
8.3 KiB
Python

"""KVCacheManager — owns all KV cache state for one model instance.
Responsibilities:
- Build per-layer pools and allocators at startup.
- Hand out state-cache slots when requests are admitted.
- Hand out classical blocks when layers need to flush compression.
- Compose LayerCacheHandle for each layer per forward call.
- Reclaim slots and blocks on request completion.
Not on the manager:
- On-disk prefix storage. (Paper §3.5.2 — deferred entirely.)
- Eviction policies. (Single-instance; requests run to completion.)
- Cross-instance coordination.
"""
from __future__ import annotations
from typing import List, Optional, Dict
import torch
from dsv4.model.config import DSV4Config
from dsv4.model.layer_schedule import LayerSpec, AttentionType
from dsv4.cache.schema import LayerCacheSchema, build_schema, compute_block_budget
from dsv4.cache.allocator import BlockAllocator
from dsv4.cache.paged_cache import PagedKVPool
from dsv4.cache.state_cache import StateCachePool
from dsv4.cache.handle import LayerCacheHandle
class KVCacheManager:
    """Own the per-layer cache pools, block allocators and request bookkeeping.

    One instance is built per model instance. It creates a state-cache pool
    for every layer, and a paged pool plus block allocator for every layer
    whose schema reports ``entries_per_block > 0``. It also tracks which
    request slots are in use and which physical blocks each slot holds in
    each layer.

    ``admit_request``, ``release_request`` and ``allocate_block`` read tensor
    values back to the host (``int(tensor)``) and are meant to be called
    between captured graphs. ``acquire`` only wraps existing tensors and
    performs no allocation.
    """

    # Number of original tokens covered by one classical block.
    # TODO(review): this duplicates BLOCK_SIZE_ORIGINAL_TOKENS (presumably
    # defined in dsv4.cache.schema); import the shared constant once its
    # export is confirmed.
    _BLOCK_SIZE_ORIGINAL_TOKENS = 128

    def __init__(
        self,
        config: DSV4Config,
        schedule: List[LayerSpec],
        max_concurrent_requests: int,
        max_context_tokens: int = 1_000_000,
        # Per-layer-type block budget. If None, computed from
        # max_context_tokens and max_concurrent_requests.
        num_blocks_per_csa_layer: Optional[int] = None,
        num_blocks_per_hca_layer: Optional[int] = None,
        device: str = "cuda",
    ):
        """Build schemas, pools, allocators and request bookkeeping tensors.

        Args:
            config: Model configuration, forwarded to ``build_schema`` and
                ``compute_block_budget``.
            schedule: One ``LayerSpec`` per layer; ``spec.layer_idx`` keys
                every per-layer dictionary on the manager.
            max_concurrent_requests: Number of request slots in every pool
                and in the block tables.
            max_context_tokens: Longest context a single request may reach;
                sizes the per-request block table.
            num_blocks_per_csa_layer: Physical blocks per CSA layer, or None
                to take the ``"csa"`` entry of the computed budget.
            num_blocks_per_hca_layer: Physical blocks per non-CSA compressed
                layer, or None to take the ``"hca"`` entry of the budget.
            device: Device string passed to every pool, allocator and tensor.
        """
        self.config = config
        self.schedule = schedule
        self.max_concurrent_requests = max_concurrent_requests
        self.device = device

        # ---- Per-layer schemas ----
        self.schemas: Dict[int, LayerCacheSchema] = {
            spec.layer_idx: build_schema(config, spec) for spec in schedule
        }

        # ---- Compute block budgets if not provided ----
        # Only None means "not provided"; an explicit 0 is honoured rather
        # than silently replaced by the computed budget.
        if num_blocks_per_csa_layer is None or num_blocks_per_hca_layer is None:
            budget = compute_block_budget(
                config, schedule, max_context_tokens, max_concurrent_requests,
            )
            if num_blocks_per_csa_layer is None:
                num_blocks_per_csa_layer = budget.get("csa", 0)
            if num_blocks_per_hca_layer is None:
                num_blocks_per_hca_layer = budget.get("hca", 0)

        # ---- Per-layer pools ----
        # State cache exists for every layer.
        self.state_pools: Dict[int, StateCachePool] = {
            i: StateCachePool(schema, max_concurrent_requests, device)
            for i, schema in self.schemas.items()
        }

        # Classical paged pool only for compressed layers
        # (entries_per_block == 0 marks a layer without one).
        self.paged_pools: Dict[int, Optional[PagedKVPool]] = {}
        self.allocators: Dict[int, Optional[BlockAllocator]] = {}
        for i, schema in self.schemas.items():
            if schema.entries_per_block == 0:
                self.paged_pools[i] = None
                self.allocators[i] = None
                continue
            nb = (
                num_blocks_per_csa_layer
                if schema.attn_type == AttentionType.CSA
                else num_blocks_per_hca_layer
            )
            self.paged_pools[i] = PagedKVPool(schema, nb, device)
            self.allocators[i] = BlockAllocator(nb, device)

        # ---- Request state ----
        # Slot index per request, into state cache pools (same index in
        # every layer). -1 = slot free.
        self.request_slot_map: torch.Tensor = torch.full(
            (max_concurrent_requests,), -1, dtype=torch.int32, device=device,
        )

        # Block table per request per layer:
        #   block_tables[layer_idx][request_slot, logical_block_idx]
        #     -> physical_block_idx
        # Ceiling division: a context that is not a multiple of the block
        # size still needs a table entry for its last, partially filled block.
        block_tokens = self._BLOCK_SIZE_ORIGINAL_TOKENS
        max_blocks = (max_context_tokens + block_tokens - 1) // block_tokens
        self.max_blocks_per_request = max_blocks
        self.block_tables: Dict[int, torch.Tensor] = {}
        self.block_lens: Dict[int, torch.Tensor] = {}
        for i, schema in self.schemas.items():
            if schema.entries_per_block > 0:
                self.block_tables[i] = torch.full(
                    (max_concurrent_requests, max_blocks), -1,
                    dtype=torch.int32, device=device,
                )
                self.block_lens[i] = torch.zeros(
                    (max_concurrent_requests,), dtype=torch.int32, device=device,
                )

    # ------------------------------------------------------------------
    # Request lifecycle (called between captured graphs)
    # ------------------------------------------------------------------
    def admit_request(self) -> int:
        """Claim the lowest-indexed free request slot and return its index.

        Raises:
            RuntimeError: If every slot is already in use.
        """
        free = (self.request_slot_map == -1).nonzero(as_tuple=False)
        if free.numel() == 0:
            raise RuntimeError("max concurrent requests exceeded")
        slot = int(free[0])
        # A used slot stores its own index; -1 marks a free one.
        self.request_slot_map[slot] = slot
        return slot

    def release_request(self, slot: int) -> None:
        """Return a slot's blocks to their allocators and mark the slot free.

        For every layer with an allocator, the first ``block_lens[slot]``
        block-table entries are released, then the slot's table row and
        length are cleared. Every state pool's ``reset_slot`` is called last.

        Args:
            slot: Slot index previously returned by ``admit_request``.
        """
        for layer_idx, alloc in self.allocators.items():
            if alloc is None:
                continue
            table = self.block_tables[layer_idx]
            lens = self.block_lens[layer_idx]
            valid = int(lens[slot])
            if valid > 0:
                # Clone so the allocator gets its own copy of the ids before
                # the table row is overwritten below.
                alloc.release(table[slot, :valid].clone())
            lens[slot] = 0
            table[slot].fill_(-1)

        # Reset state cache slot.
        for state in self.state_pools.values():
            state.reset_slot(slot)
        self.request_slot_map[slot] = -1

    # ------------------------------------------------------------------
    # Block allocation for compression flush (called between captures)
    # ------------------------------------------------------------------
    def allocate_block(self, layer_idx: int, request_slot: int) -> int:
        """Allocate one new classical block for a request and record it.

        The block id is appended to the request's block table for the layer
        and the request's block count is incremented.

        Args:
            layer_idx: Layer whose allocator supplies the block.
            request_slot: Row of the block table to append to.

        Returns:
            The physical block id that was acquired.

        Raises:
            AssertionError: If the layer has no classical pool, or the
                request's block table is already full. Raised explicitly so
                the checks survive ``python -O``.
        """
        alloc = self.allocators[layer_idx]
        if alloc is None:
            raise AssertionError(f"layer {layer_idx} has no classical pool")

        table = self.block_tables[layer_idx]
        lens = self.block_lens[layer_idx]
        pos = int(lens[request_slot])
        # Check capacity BEFORE acquiring: a block taken from the free list
        # but never written to the table could not be released again.
        if pos >= self.max_blocks_per_request:
            raise AssertionError("block table overflow")

        block_id = alloc.acquire(1)
        bid = int(block_id[0])
        # Append to the request's block table.
        table[request_slot, pos] = bid
        lens[request_slot] = pos + 1
        return bid

    # ------------------------------------------------------------------
    # Per-forward handle construction (called INSIDE captured graph)
    # ------------------------------------------------------------------
    def acquire(
        self,
        layer_idx: int,
        request_slots: torch.Tensor,  # [batch] int32
        positions: torch.Tensor,  # [tokens] int32
        request_ids: torch.Tensor,  # [tokens] int32
    ) -> LayerCacheHandle:
        """Build the LayerCacheHandle for one layer's forward.

        No allocation happens here — critical for cudagraph safety. The
        handle receives the layer's pools and, when the layer has a paged
        pool, the full block table and block-length tensors; otherwise both
        are None.
        """
        paged = self.paged_pools[layer_idx]
        state = self.state_pools[layer_idx]
        if paged is not None:
            # Pass the full tensors — no indexing, no allocation.
            # The attention kernel indexes by request_slots internally.
            block_table = self.block_tables[layer_idx]
            block_lens = self.block_lens[layer_idx]
        else:
            block_table = None
            block_lens = None
        return LayerCacheHandle(
            paged=paged,
            state=state,
            request_slots=request_slots,
            positions=positions,
            request_ids=request_ids,
            block_table=block_table,
            block_lens=block_lens,
        )

    # ------------------------------------------------------------------
    # Diagnostics
    # ------------------------------------------------------------------
    def memory_bytes(self) -> int:
        """Return the bytes held by the pools and block bookkeeping tensors.

        Sums ``memory_bytes()`` of every state pool and every paged pool,
        plus the raw size of each block table and block-length tensor.
        ``request_slot_map`` and the allocators are not counted.
        """
        total = 0
        for pool in self.state_pools.values():
            total += pool.memory_bytes()
        for pool in self.paged_pools.values():
            if pool is not None:
                total += pool.memory_bytes()
        for i, table in self.block_tables.items():
            total += table.numel() * table.element_size()
            total += self.block_lens[i].numel() * self.block_lens[i].element_size()
        return total