nvfp4-megamoe-kernel/dsv4/cache/schema.py

KV Cache: schema, allocator, pools, manager, append_swa kernel Complete KV cache substrate for DSV4 inference: schema.py: Per-layer cache shape derived from LayerSpec. - CSA: 32 entries/block, 32 indexer entries, tail=3 - HCA: 1 entry/block, no indexer, tail=127 - SWA: no classical pool, no tail - BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios) - compute_block_budget() for allocator sizing allocator.py: Fixed-size block free-list. - GPU stack with pinned host top pointer - acquire/release between graph captures only - OOM raises on exhaustion paged_cache.py: Per-layer classical KV storage. - FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper 2.3.4) - Per-entry inverse scale for FP8 dequant - FP4 indexer keys for CSA layers (NVFP4 scheme) - memory_bytes() tracking state_cache.py: Per-layer SWA window + tail buffer. - Ring buffer with position tracking (swa_head, swa_pos) - CSA: dual streams (ka/za/kb/zb) for overlapping compression - HCA: single stream (ka/za only) - SWA: no tail buffer - reset_slot() for request completion handle.py: LayerCacheHandle — typed per-call view. - write_swa(), read_swa_view(), read_classical_view(), read_indexer_view() - No GPU allocation in acquire() — 0 bytes delta (cudagraph safe) - SWAView/ClassicalView/IndexerView dataclasses for kernel signatures manager.py: KVCacheManager — owns everything. - Per-layer schema, pool, and allocator construction - admit_request()/release_request() lifecycle - allocate_block() for compression flush - acquire() returns LayerCacheHandle (zero-alloc) append_swa.cu: Native kernel for SWA writes. 
- One block per token, 128 threads per block - Warp-level amax reduction, BF16->FP8 E4M3 quantization - Atomic ring buffer head increment - FP8/BF16 split write + inv_scale + position metadata - FP8 round-trip: <3.6% relative error - RoPE half: exact match (no quantization) All tests pass on B200: - Schema correctness for CSA/HCA/SWA - Allocator acquire/release/OOM - Pool shapes match architecture spec - Manager lifecycle (admit/release/recycle/exhaustion) - Zero-alloc acquire() (cudagraph safe) - append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
2026-05-22 00:08:38 +00:00
"""Per-layer KV cache shape.
Computed once per layer at engine startup from the LayerSpec. The
schema is what tells the allocator how big each pool slot is and what
sub-regions exist (compressed entries / indexer keys / SWA window /
uncompressed tail).
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from dsv4.model.config import DSV4Config
from dsv4.model.layer_schedule import LayerSpec, AttentionType
# Block size is invariant for DSV4 — derived from compression ratios.
# lcm(m, m') = lcm(4, 128) = 128 original tokens per block.
# Holds 128/4 = 32 CSA entries OR 128/128 = 1 HCA entry per block.
BLOCK_SIZE_ORIGINAL_TOKENS = 128
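The block size follows directly from the two compression ratios in the comment above (m = 4 for CSA, m' = 128 for HCA). The sketch below assumes those two values and uses the standard-library `math.lcm` (Python 3.9+). The helper `entries_per_block` is hypothetical and is not part of this module:

```python
import math

CSA_RATIO = 4    # m, assumed from the comment above
HCA_RATIO = 128  # m', assumed from the comment above

# Smallest token span that holds a whole number of entries for both ratios.
BLOCK_TOKENS = math.lcm(CSA_RATIO, HCA_RATIO)


def entries_per_block(ratio: int, block_tokens: int = BLOCK_TOKENS) -> int:
    """Compressed entries that fit in one block for a given compression ratio."""
    if block_tokens % ratio != 0:
        raise ValueError("block size must be a multiple of the compression ratio")
    return block_tokens // ratio


print(BLOCK_TOKENS, entries_per_block(CSA_RATIO), entries_per_block(HCA_RATIO))
# prints: 128 32 1
```

Because 128 is a multiple of both ratios, every block boundary is also a compression boundary for both CSA and HCA layers.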
@dataclass(frozen=True)
class LayerCacheSchema:
    """Cache layout for one transformer layer.

    Fields suffixed `_per_block` give the dimensions of one block in the
    classical paged pool. The state-cache fields (`swa_window_size`,
    `tail_buffer_size_a`, `tail_buffer_size_b`) give the dimensions of one
    request's slot in the state cache.

    All sizes are counts of entries; byte sizes follow from the dtypes.
    """
    layer_idx: int
    attn_type: AttentionType
    # ---- Classical paged cache (compressed entries) ----
    entries_per_block: int
    entry_head_dim: int
    rope_dim: int
    # ---- Indexer pool (CSA only) ----
    indexer_entries_per_block: int
    indexer_head_dim: int
    # ---- State cache (SWA window + uncompressed tail) ----
    swa_window_size: int
    # CSA: paper eq.11-12, the i-th flush uses Ca[m*i:m*(i+1)] and
    # Cb[m*(i-1):m*i]. After flush, current a-stream becomes next b-stream.
    # So we need m entries for current a-stream AND m entries for previous
    # b-stream. Total tail = 2*m for CSA.
    tail_buffer_size_a: int  # m (CSA) or m' (HCA): current tokens
    tail_buffer_size_b: int  # m (CSA only): previous block's a-stream kept as b-input
    # Per-token inverse scale storage (for FP8 dequant).
    needs_inv_scale: bool = True

    @property
    def tail_buffer_size(self) -> int:
        """Total tail entries (for backward compat with schema consumers)."""
        return self.tail_buffer_size_a + self.tail_buffer_size_b
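The a/b-stream comment in the class body can be made concrete. The sketch below is a hypothetical illustration, not code from this repository: it computes the row ranges that the comment assigns to the i-th CSA flush and checks that the b-range of flush i is exactly the a-range of flush i-1, which is why `tail_buffer_size_b` equals m for CSA and 0 for HCA. `TailSizes` is an assumed stand-in for the two tail fields of `LayerCacheSchema`:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class TailSizes:
    # Hypothetical stand-in for the two tail fields of LayerCacheSchema.
    a: int  # current-stream entries (m for CSA, m' for HCA)
    b: int  # previous-stream entries kept as b-input (m for CSA, 0 for HCA)

    @property
    def total(self) -> int:
        return self.a + self.b


def flush_ranges(i: int, m: int) -> tuple[range, range]:
    """Row ranges feeding the i-th CSA flush, per the comment's reading of eq. 11-12.

    The a-stream covers [m*i, m*(i+1)); the b-stream covers [m*(i-1), m*i),
    i.e. exactly the previous flush's a-stream. For i == 0 the b-range is
    negative and must be padded (zeros for C^b, -inf-like logits for Z^b).
    """
    return range(m * i, m * (i + 1)), range(m * (i - 1), m * i)


csa = TailSizes(a=4, b=4)
hca = TailSizes(a=128, b=0)
a1, b1 = flush_ranges(1, 4)
a0, _ = flush_ranges(0, 4)
print(csa.total, hca.total, list(a1), list(b1), b1 == a0)
# prints: 8 128 [4, 5, 6, 7] [0, 1, 2, 3] True
```

The padding note for i == 0 mirrors the commit history's description of initializing `tail_zb` to -1e9 so the softmax masks the empty b-stream on the first flush.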


def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema:
    """Derive cache schema for a single layer from architectural config."""
    if spec.attn == AttentionType.CSA:
        return LayerCacheSchema(
            layer_idx=spec.layer_idx,
            attn_type=AttentionType.CSA,
            entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // config.csa_compression_ratio,
            entry_head_dim=config.head_dim,
            rope_dim=config.rope_dim,
            indexer_entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // config.csa_compression_ratio,
            indexer_head_dim=config.indexer_head_dim,
            swa_window_size=config.sliding_window,
            tail_buffer_size_a=config.csa_compression_ratio,  # m=4 current
            tail_buffer_size_b=config.csa_compression_ratio,  # m=4 previous (b-stream)
        )
    elif spec.attn == AttentionType.HCA:
        return LayerCacheSchema(
            layer_idx=spec.layer_idx,
            attn_type=AttentionType.HCA,
            entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // config.hca_compression_ratio,
            entry_head_dim=config.head_dim,
            rope_dim=config.rope_dim,
            indexer_entries_per_block=0,
            indexer_head_dim=0,
            swa_window_size=config.sliding_window,
            tail_buffer_size_a=config.hca_compression_ratio,  # m'=128 current
            tail_buffer_size_b=0,  # HCA has no b-stream
        )
    else:  # SWA-only
        return LayerCacheSchema(
            layer_idx=spec.layer_idx,
            attn_type=AttentionType.SWA,
            entries_per_block=0,
            entry_head_dim=config.head_dim,
            rope_dim=config.rope_dim,
            indexer_entries_per_block=0,
            indexer_head_dim=0,
            swa_window_size=config.sliding_window,
            tail_buffer_size_a=0,
            tail_buffer_size_b=0,
        )
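For a quick sanity check of the three branches in `build_schema`, the sketch below mirrors only their size-related fields. It is a hypothetical illustration: `Attn` stands in for `AttentionType`, `cache_shape` is not part of this module, and the default ratios (4 and 128) are taken from the comments in this file rather than read from `DSV4Config`:

```python
from __future__ import annotations

from enum import Enum


class Attn(Enum):  # hypothetical stand-in for AttentionType
    CSA = "csa"
    HCA = "hca"
    SWA = "swa"


BLOCK = 128  # BLOCK_SIZE_ORIGINAL_TOKENS


def cache_shape(attn: Attn, csa_ratio: int = 4, hca_ratio: int = 128) -> dict[str, int]:
    """Mirror the per-type branching in build_schema for the size fields only."""
    if attn is Attn.CSA:
        per_block = BLOCK // csa_ratio
        return {"entries": per_block, "indexer": per_block,
                "tail_a": csa_ratio, "tail_b": csa_ratio}
    if attn is Attn.HCA:
        return {"entries": BLOCK // hca_ratio, "indexer": 0,
                "tail_a": hca_ratio, "tail_b": 0}
    # SWA-only: no classical pool, no indexer, no tail; only the window remains.
    return {"entries": 0, "indexer": 0, "tail_a": 0, "tail_b": 0}


for attn in Attn:
    print(attn.name, cache_shape(attn))
```

With these ratios the CSA tail is 4 + 4 and the HCA tail is 128 + 0, matching the values the commit history reports for the schema tests.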


def compute_block_budget(
    config: DSV4Config,
    schedule: list[LayerSpec],
    max_context_tokens: int,
    max_concurrent_requests: int,
) -> dict[str, int]:
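    """Compute per-layer-type block counts for the allocator.

    Returns a dict keyed by "csa" / "hca". SWA-only layers have no
    classical paged pool and are skipped.
    """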
    # Round up: a context that ends mid-block still needs that last block.
    blocks_per_request = -(-max_context_tokens // BLOCK_SIZE_ORIGINAL_TOKENS)
    headroom = 1.10  # 10% slack over max_concurrent_requests * blocks_per_request
    result = {}
    for spec in schedule:
        if spec.attn == AttentionType.CSA:
            key = "csa"
        elif spec.attn == AttentionType.HCA:
            key = "hca"
        else:
            continue  # SWA-only layers have no classical pool
        total = int(max_concurrent_requests * blocks_per_request * headroom)
        result[key] = max(result.get(key, 0), total)
    return result
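The budget arithmetic can be checked in isolation. In this sketch the function name `blocks_for_pool` is hypothetical, and the 128-token block and 1.10 headroom are taken from this file. It rounds the per-request block count up with ceiling division, so a context that ends mid-block is still fully backed. Flooring instead would cover only 896 of 1000 tokens for a single request, because the 10% headroom on 7 blocks truncates back to 7:

```python
def blocks_for_pool(
    max_context_tokens: int,
    max_concurrent_requests: int,
    block_tokens: int = 128,
    headroom: float = 1.10,
) -> int:
    """Blocks for one pool type, rounding each request's need up to a full block."""
    blocks_per_request = -(-max_context_tokens // block_tokens)  # ceiling division
    return int(max_concurrent_requests * blocks_per_request * headroom)


ctx = 1000
floored = ctx // 128  # what plain floor division would give per request
print(blocks_for_pool(ctx, 1), floored, floored * 128)
# prints: 8 7 896
```

Integer ceiling division (`-(-a // b)`) avoids float rounding that `math.ceil(a / b)` could introduce for very large token counts.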