Complete KV cache substrate for DSV4 inference:

schema.py: Per-layer cache shape derived from LayerSpec.
- CSA: 32 entries/block, 32 indexer entries, tail=3
- HCA: 1 entry/block, no indexer, tail=127
- SWA: no classical pool, no tail
- BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios)
- compute_block_budget() for allocator sizing

allocator.py: Fixed-size block free-list.
- GPU stack with pinned host top pointer
- acquire/release between graph captures only
- OOM raises on exhaustion

paged_cache.py: Per-layer classical KV storage.
- FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper 2.3.4)
- Per-entry inverse scale for FP8 dequant
- FP4 indexer keys for CSA layers (NVFP4 scheme)
- memory_bytes() tracking

state_cache.py: Per-layer SWA window + tail buffer.
- Ring buffer with position tracking (swa_head, swa_pos)
- CSA: dual streams (ka/za/kb/zb) for overlapping compression
- HCA: single stream (ka/za only)
- SWA: no tail buffer
- reset_slot() for request completion

handle.py: LayerCacheHandle — typed per-call view.
- write_swa(), read_swa_view(), read_classical_view(), read_indexer_view()
- No GPU allocation in acquire() — 0 bytes delta (cudagraph safe)
- SWAView/ClassicalView/IndexerView dataclasses for kernel signatures

manager.py: KVCacheManager — owns everything.
- Per-layer schema, pool, and allocator construction
- admit_request()/release_request() lifecycle
- allocate_block() for compression flush
- acquire() returns LayerCacheHandle (zero-alloc)

append_swa.cu: Native kernel for SWA writes.
- One block per token, 128 threads per block
- Warp-level amax reduction, BF16->FP8 E4M3 quantization
- Atomic ring buffer head increment
- FP8/BF16 split write + inv_scale + position metadata
- FP8 round-trip: <3.6% relative error
- RoPE half: exact match (no quantization)

All tests pass on B200:
- Schema correctness for CSA/HCA/SWA
- Allocator acquire/release/OOM
- Pool shapes match architecture spec
- Manager lifecycle (admit/release/recycle/exhaustion)
- Zero-alloc acquire() (cudagraph safe)
- append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
130 lines
4.8 KiB
Python
130 lines
4.8 KiB
Python
"""Per-layer KV cache shape.

Computed once per layer at engine startup from the LayerSpec. The
schema is what tells the allocator how big each pool slot is and what
sub-regions exist (compressed entries / indexer keys / SWA window /
uncompressed tail).
"""
|
|
from __future__ import annotations
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
from dsv4.model.config import DSV4Config
|
|
from dsv4.model.layer_schedule import LayerSpec, AttentionType
|
|
|
|
|
|
# Number of original (uncompressed) tokens covered by one cache block.
# This is fixed for DSV4 because it follows from the two compression
# ratios: lcm(m, m') = lcm(4, 128) = 128.
# One block therefore stores either 128 / 4 = 32 CSA entries or
# 128 / 128 = 1 HCA entry.
BLOCK_SIZE_ORIGINAL_TOKENS = 128
|
|
|
|
|
|
@dataclass(frozen=True)
class LayerCacheSchema:
    """Immutable description of one transformer layer's KV-cache layout.

    The fields fall into three groups:

    * classical paged pool -- sizes of one block of compressed entries;
    * indexer pool -- sizes of the per-block indexer keys (CSA layers only);
    * state cache -- sizes of one request's slot (SWA window plus the
      not-yet-compressed tail).

    Every size is a count of entries, never a byte count; byte sizes
    follow from the dtypes used for storage.
    """

    # Index of the layer this schema describes, and its attention type.
    layer_idx: int
    attn_type: AttentionType

    # ---- Classical paged pool (compressed entries) ----
    # Compressed entries held by one block, where a block spans
    # BLOCK_SIZE_ORIGINAL_TOKENS original tokens:
    #   CSA (m = 4)    -> 32
    #   HCA (m' = 128) -> 1
    #   SWA-only       -> 0 (the layer has no classical pool)
    entries_per_block: int
    # Width of a single entry, i.e. the head dimension.
    entry_head_dim: int
    # How many of those dimensions carry RoPE and are kept in BF16;
    # the remaining dimensions are stored as FP8.
    rope_dim: int

    # ---- Indexer pool (CSA layers only) ----
    # One compressed indexer key accompanies each compressed entry.
    indexer_entries_per_block: int  # 32 for CSA, 0 for HCA/SWA
    indexer_head_dim: int  # 128 for CSA, 0 for HCA/SWA

    # ---- State cache (SWA window + uncompressed tail) ----
    swa_window_size: int  # 128 for every layer type
    # Tokens waiting to be compressed; only compressed layers need this.
    #   CSA      -> at most m - 1 = 3 tokens pending before a flush
    #   HCA      -> at most m' - 1 = 127 tokens pending
    #   SWA-only -> 0 (there is no compression branch)
    tail_buffer_size: int

    # Whether an inverse scale is stored for FP8 dequantisation: one FP32
    # scalar per stored entry / window slot.
    needs_inv_scale: bool = True
|
|
|
|
|
|
def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema:
    """Build the cache schema of one layer from the model config.

    Args:
        config: Model configuration; supplies head/RoPE dimensions, the
            sliding-window size, the compression ratios and the indexer
            head dimension.
        spec: Layer descriptor; its ``attn`` selects the layout and its
            ``layer_idx`` is copied into the schema.

    Returns:
        A ``LayerCacheSchema`` for a CSA, HCA or SWA-only layer. Any
        attention type other than CSA or HCA is treated as SWA-only.
    """
    attn = spec.attn

    # Fields every layer type fills in the same way.
    shared = dict(
        layer_idx=spec.layer_idx,
        entry_head_dim=config.head_dim,
        rope_dim=config.rope_dim,
        swa_window_size=config.sliding_window,
    )

    if attn == AttentionType.CSA:
        ratio = config.csa_compression_ratio
        per_block = BLOCK_SIZE_ORIGINAL_TOKENS // ratio
        return LayerCacheSchema(
            attn_type=AttentionType.CSA,
            entries_per_block=per_block,
            # One indexer key per compressed entry.
            indexer_entries_per_block=per_block,
            indexer_head_dim=config.indexer_head_dim,
            # Up to ratio - 1 tokens can be pending before a flush.
            tail_buffer_size=ratio - 1,
            **shared,
        )

    if attn == AttentionType.HCA:
        ratio = config.hca_compression_ratio
        return LayerCacheSchema(
            attn_type=AttentionType.HCA,
            entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // ratio,
            # HCA layers have no indexer pool.
            indexer_entries_per_block=0,
            indexer_head_dim=0,
            tail_buffer_size=ratio - 1,
            **shared,
        )

    # SWA-only: no classical pool, no indexer, no uncompressed tail.
    return LayerCacheSchema(
        attn_type=AttentionType.SWA,
        entries_per_block=0,
        indexer_entries_per_block=0,
        indexer_head_dim=0,
        tail_buffer_size=0,
        **shared,
    )
|
|
|
|
|
|
def compute_block_budget(
    config: DSV4Config,
    schedule: list[LayerSpec],
    max_context_tokens: int,
    max_concurrent_requests: int,
) -> dict[str, int]:
    """Compute per-layer-type block counts for the allocator.

    Every compressed layer type that occurs in ``schedule`` receives the
    same budget::

        blocks_per_request = ceil(max_context_tokens / BLOCK_SIZE_ORIGINAL_TOKENS)
        budget = ceil(max_concurrent_requests * blocks_per_request * 1.10)

    The extra 10% is headroom for fragmentation.

    Args:
        config: Model configuration. Not read by the current
            implementation; kept so the signature stays stable.
        schedule: Layer specs of the model; only each spec's ``attn`` is
            inspected.
        max_context_tokens: Longest context, in original tokens, a single
            request may reach.
        max_concurrent_requests: Maximum number of requests resident at
            the same time.

    Returns:
        ``{layer_type: num_blocks}`` with ``layer_type`` being ``'csa'``
        and/or ``'hca'``, restricted to the types present in
        ``schedule``. SWA-only layers need no classical blocks and so
        contribute no key.
    """
    # Ceiling division: a context whose length is not a multiple of the
    # block size still occupies a final, partially filled block. Floor
    # division here would under-budget by one block per request.
    blocks_per_request = -(-max_context_tokens // BLOCK_SIZE_ORIGINAL_TOKENS)
    base_blocks = max_concurrent_requests * blocks_per_request

    # 10% headroom in pure integer arithmetic, rounded up. Truncating
    # instead would leave budgets below 10 blocks with no headroom at all.
    total = (base_blocks * 11 + 9) // 10

    # The budget does not depend on the individual layer, so it is computed
    # once above; the loop only records which compressed types are present.
    result: dict[str, int] = {}
    for spec in schedule:
        if spec.attn == AttentionType.CSA:
            result.setdefault("csa", total)
        elif spec.attn == AttentionType.HCA:
            result.setdefault("hca", total)
    return result
|