Schema fix (paper eq. 11-12):
CSA needs m entries for the current a-stream AND m entries for the previous
b-stream (tail_buffer_size_a=4, tail_buffer_size_b=4). After a flush, the
current a-stream becomes the b-stream input of the next flush.
HCA: tail_buffer_size_a=128, tail_buffer_size_b=0 (no b-stream).
tail_zb is initialized to -1e9 so that softmax naturally masks the b-stream on
the first flush (paper: Z^b padded with -inf, C^b with zeros).
prepare_forward.py:
Runs between captured CUDA graphs. Computes the number of new compressed
entries from the position delta and pre-allocates blocks before the graph runs.
Deterministic: entries_after - entries_before, rounded up to the block boundary.
No allocation happens inside the captured graph.
flush_write.cu — 4 kernels:
flush_write_csa_kernel: BF16 -> FP8 E4M3 quantize + scatter of the compressed
entry + NVFP4 (FP4) indexer key write (16-element groups, E4M3 scale).
One block per request, 128 threads. Amax reduction -> inv_scale.
flush_write_hca_kernel: same as the CSA kernel, minus the indexer (no FP4 write).
csa_rotate_state_kernel: after a CSA flush, rotate the a-stream into the
b-stream, clear the a-stream, and reset tail_len.
hca_reset_state_kernel: after an HCA flush, clear the a-stream and reset tail_len.
flush.py: Python orchestration.
maybe_flush_csa/hca: always runs; the kernels gate on valid_mask.
The compressor produces the entry, the flush kernel quantizes and scatters it,
and the state kernel rotates/resets. There is no host-side branching, so the
path stays cudagraph-capturable.
All tests pass on B200:
Schema: CSA tail_a=4 tail_b=4, HCA tail_a=128 tail_b=0
State: tail_zb initialized to -1e9, reset_slot preserves it
prepare_forward: correct block allocation for position transitions
HCA flush write: RoPE exact, FP8 <3.6% error, invalid mask no-op
CSA flush write: RoPE exact, indexer FP4 keys written
CSA state rotation: kb<-ka, zb<-za, ka/za zeroed, tail_len=0
HCA state reset: ka/za zeroed, tail_len=0
Attached file (Python, 126 lines, 4.6 KiB):
"""Per-layer KV cache shape.

Computed once per layer at engine startup from the LayerSpec. The
schema tells the allocator how big each pool slot is and which
sub-regions exist (compressed entries / indexer keys / SWA window /
uncompressed tail).
"""
|
|
from __future__ import annotations
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
from dsv4.model.config import DSV4Config
|
|
from dsv4.model.layer_schedule import LayerSpec, AttentionType
|
|
|
|
|
|
# One paged-cache block spans a fixed count of ORIGINAL (uncompressed) tokens.
# For DSV4 this is lcm(m, m') with CSA ratio m = 4 and HCA ratio m' = 128,
# i.e. lcm(4, 128) = 128 tokens -> 128 / 4 = 32 CSA entries, or
# 128 / 128 = 1 HCA entry, per block. The value is fixed here rather than
# recomputed from the config.
BLOCK_SIZE_ORIGINAL_TOKENS: int = 128
|
|
|
|
|
|
@dataclass(frozen=True)
class LayerCacheSchema:
    """Immutable description of one transformer layer's KV cache layout.

    Two storage areas are described:

    * the classical paged pool, whose per-block dimensions are the
      ``*_per_block`` fields (compressed entries, plus indexer keys for CSA);
    * the per-request state cache, holding the SWA window and the
      uncompressed tail (split into an a-stream and a b-stream).

    Every size is an entry count; byte sizes follow from the dtypes.
    """

    layer_idx: int
    attn_type: AttentionType

    # Paged pool: compressed KV entries.
    entries_per_block: int
    entry_head_dim: int
    rope_dim: int

    # Paged pool: indexer keys (zero-sized for non-CSA layers).
    indexer_entries_per_block: int
    indexer_head_dim: int

    # State cache: sliding window plus the not-yet-compressed tail.
    swa_window_size: int

    # Paper eq. 11-12 for CSA: flush i reads Ca[m*i : m*(i+1)] together with
    # Cb[m*(i-1) : m*i], and the current a-stream then serves as the next
    # flush's b-stream. Hence CSA keeps m + m tail entries; HCA has no b-stream.
    tail_buffer_size_a: int  # current tokens: m for CSA, m' for HCA
    tail_buffer_size_b: int  # CSA only: m entries of the previous a-stream

    # Whether per-token inverse scales (FP8 dequantisation) are stored.
    needs_inv_scale: bool = True

    @property
    def tail_buffer_size(self) -> int:
        """Combined a-stream + b-stream tail length (kept for older consumers)."""
        current, previous = self.tail_buffer_size_a, self.tail_buffer_size_b
        return current + previous
|
|
|
|
|
|
def _entries_per_block(ratio: int, field_name: str) -> int:
    """Return how many compressed entries with compression *ratio* fit in one block.

    Raises ValueError unless *ratio* is a positive divisor of
    BLOCK_SIZE_ORIGINAL_TOKENS. Plain floor division would otherwise silently
    drop the remainder tokens of every block, or yield 0 entries per block
    when *ratio* exceeds the block size.
    """
    # Short-circuit on ratio <= 0 so the modulo never divides by zero.
    if ratio <= 0 or BLOCK_SIZE_ORIGINAL_TOKENS % ratio != 0:
        raise ValueError(
            f"{field_name}={ratio} must be a positive divisor of "
            f"BLOCK_SIZE_ORIGINAL_TOKENS={BLOCK_SIZE_ORIGINAL_TOKENS}"
        )
    return BLOCK_SIZE_ORIGINAL_TOKENS // ratio


def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema:
    """Derive the cache schema for a single layer from the architectural config.

    Args:
        config: Model config supplying head dims, the RoPE dim, the sliding
            window size and the CSA/HCA compression ratios.
        spec: Layer spec; ``spec.attn`` selects the layout and
            ``spec.layer_idx`` is copied into the schema.

    Returns:
        A LayerCacheSchema:

        * CSA: compressed pool plus an indexer pool with one key per
          compressed entry, and an a/b tail of m entries each (paper
          eq. 11-12: flush i consumes the current a-stream plus the previous
          a-stream as its b-input).
        * HCA: compressed pool and an m'-entry a-tail, with no b-stream and
          no indexer.
        * Any other attention type is treated as SWA-only: no paged
          compressed entries and no tail.

    Raises:
        ValueError: if the compression ratio used by a CSA/HCA layer is not a
            positive divisor of BLOCK_SIZE_ORIGINAL_TOKENS.
    """
    if spec.attn == AttentionType.CSA:
        m = config.csa_compression_ratio
        entries = _entries_per_block(m, "csa_compression_ratio")
        return LayerCacheSchema(
            layer_idx=spec.layer_idx,
            attn_type=AttentionType.CSA,
            entries_per_block=entries,
            entry_head_dim=config.head_dim,
            rope_dim=config.rope_dim,
            # One indexer key per compressed entry.
            indexer_entries_per_block=entries,
            indexer_head_dim=config.indexer_head_dim,
            swa_window_size=config.sliding_window,
            tail_buffer_size_a=m,  # current a-stream (m=4)
            tail_buffer_size_b=m,  # previous a-stream reused as b-input (m=4)
        )
    elif spec.attn == AttentionType.HCA:
        m_prime = config.hca_compression_ratio
        return LayerCacheSchema(
            layer_idx=spec.layer_idx,
            attn_type=AttentionType.HCA,
            entries_per_block=_entries_per_block(m_prime, "hca_compression_ratio"),
            entry_head_dim=config.head_dim,
            rope_dim=config.rope_dim,
            indexer_entries_per_block=0,
            indexer_head_dim=0,
            swa_window_size=config.sliding_window,
            tail_buffer_size_a=m_prime,  # current tokens (m'=128)
            tail_buffer_size_b=0,  # HCA has no b-stream
        )
    else:  # SWA-only
        return LayerCacheSchema(
            layer_idx=spec.layer_idx,
            attn_type=AttentionType.SWA,
            entries_per_block=0,
            entry_head_dim=config.head_dim,
            rope_dim=config.rope_dim,
            indexer_entries_per_block=0,
            indexer_head_dim=0,
            swa_window_size=config.sliding_window,
            tail_buffer_size_a=0,
            tail_buffer_size_b=0,
        )
|
|
|
|
|
|
def compute_block_budget(
    config: DSV4Config,
    schedule: list[LayerSpec],
    max_context_tokens: int,
    max_concurrent_requests: int,
    *,
    headroom: float = 1.10,
) -> dict[str, int]:
    """Compute per-layer-type block counts for the allocator.

    Each block spans BLOCK_SIZE_ORIGINAL_TOKENS original tokens for both CSA
    and HCA layers, so one count per attention type suffices. SWA-only layers
    have no paged compressed pool and get no entry.

    Args:
        config: Model config. Not read here; kept for interface stability.
        schedule: Per-layer specs; only which attention types appear matters.
        max_context_tokens: Longest context a single request may reach.
        max_concurrent_requests: Number of requests resident at once.
        headroom: Multiplicative over-provisioning factor (default 1.10).

    Returns:
        Mapping from ``"csa"`` and/or ``"hca"`` (only for types present in
        ``schedule``) to the number of blocks to allocate.

    Raises:
        ValueError: if a token/request count is negative or ``headroom`` is
            not positive.
    """
    if max_context_tokens < 0 or max_concurrent_requests < 0:
        raise ValueError(
            "max_context_tokens and max_concurrent_requests must be >= 0, got "
            f"{max_context_tokens} and {max_concurrent_requests}"
        )
    if headroom <= 0:
        raise ValueError(f"headroom must be positive, got {headroom}")

    # Round UP: a request at max_context_tokens that is not a multiple of the
    # block size still needs a block for its final partial span (floor
    # division would give e.g. 0 blocks for a 100-token context).
    blocks_per_request = -(-max_context_tokens // BLOCK_SIZE_ORIGINAL_TOKENS)
    # Loop-invariant: the same count applies to every CSA/HCA layer.
    total = int(max_concurrent_requests * blocks_per_request * headroom)

    result: dict[str, int] = {}
    for spec in schedule:
        if spec.attn == AttentionType.CSA:
            result["csa"] = total
        elif spec.attn == AttentionType.HCA:
            result["hca"] = total
        # Other (SWA-only) layers use no paged compressed blocks.
    return result
|