nvfp4-megamoe-kernel/dsv4/cache/state_cache.py

103 lines
4.3 KiB
Python

KV Cache: schema, allocator, pools, manager, append_swa kernel

Complete KV cache substrate for DSV4 inference:

schema.py: Per-layer cache shape derived from LayerSpec.
- CSA: 32 entries/block, 32 indexer entries, tail=3
- HCA: 1 entry/block, no indexer, tail=127
- SWA: no classical pool, no tail
- BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios)
- compute_block_budget() for allocator sizing

allocator.py: Fixed-size block free-list.
- GPU stack with pinned host top pointer
- acquire/release between graph captures only
- OOM raises on exhaustion

paged_cache.py: Per-layer classical KV storage.
- FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper 2.3.4)
- Per-entry inverse scale for FP8 dequant
- FP4 indexer keys for CSA layers (NVFP4 scheme)
- memory_bytes() tracking

state_cache.py: Per-layer SWA window + tail buffer.
- Ring buffer with position tracking (swa_head, swa_pos)
- CSA: dual streams (ka/za/kb/zb) for overlapping compression
- HCA: single stream (ka/za only)
- SWA: no tail buffer
- reset_slot() for request completion

handle.py: LayerCacheHandle, a typed per-call view.
- write_swa(), read_swa_view(), read_classical_view(), read_indexer_view()
- No GPU allocation in acquire(): 0 bytes delta (cudagraph safe)
- SWAView/ClassicalView/IndexerView dataclasses for kernel signatures

manager.py: KVCacheManager, which owns everything.
- Per-layer schema, pool, and allocator construction
- admit_request()/release_request() lifecycle
- allocate_block() for compression flush
- acquire() returns LayerCacheHandle (zero-alloc)

append_swa.cu: Native kernel for SWA writes.
- One block per token, 128 threads per block
- Warp-level amax reduction, BF16->FP8 E4M3 quantization
- Atomic ring buffer head increment
- FP8/BF16 split write + inv_scale + position metadata
- FP8 round-trip: <3.6% relative error
- RoPE half: exact match (no quantization)

All tests pass on B200:
- Schema correctness for CSA/HCA/SWA
- Allocator acquire/release/OOM
- Pool shapes match architecture spec
- Manager lifecycle (admit/release/recycle/exhaustion)
- Zero-alloc acquire() (cudagraph safe)
- append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
2026-05-22 00:08:38 +00:00
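The ring-buffer bookkeeping listed above for state_cache.py and append_swa.cu (swa_head, swa_pos, wrap-around, reset_slot) can be modeled in a few lines of plain Python. This is a hypothetical single-slot sketch, not the CUDA kernel. It assumes the head wraps modulo the window size and leaves out the FP8/BF16 payload and inv_scale. The class name `SlotRing` is invented for illustration.

```python
from typing import List


class SlotRing:
    """Hypothetical single-slot model of the SWA ring-buffer bookkeeping.

    swa_pos starts at -1 (empty). Each append writes the token's absolute
    position at swa_head and advances the head modulo the window size
    (an assumption). The FP8/BF16 payload and inv_scale are left out.
    """

    def __init__(self, window: int) -> None:
        self.window = window
        self.swa_pos: List[int] = [-1] * window  # -1 marks an empty entry
        self.swa_head = 0                        # next ring index to write

    def append(self, position: int) -> int:
        """Record a token at absolute `position`; return the ring index used."""
        idx = self.swa_head
        self.swa_pos[idx] = position
        # The CUDA kernel is described as using an atomic head increment;
        # a single-threaded model can simply advance and wrap.
        self.swa_head = (idx + 1) % self.window
        return idx

    def live_positions(self) -> List[int]:
        """Positions currently held in the window, oldest first."""
        return sorted(p for p in self.swa_pos if p >= 0)

    def reset(self) -> None:
        """SWA part of reset_slot(): mark every entry empty, rewind the head."""
        self.swa_pos = [-1] * self.window
        self.swa_head = 0


ring = SlotRing(window=4)
for pos in range(6):
    ring.append(pos)
print(ring.live_positions())  # [2, 3, 4, 5]
print(ring.swa_head)          # 2
ring.reset()
print(ring.live_positions())  # []
```

After wrap-around, the physical order of the slot (`[4, 5, 2, 3]` here) no longer matches token order. Keeping absolute positions next to the data, with -1 for unwritten entries, lets a reader recover the order and skip empty entries.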
"""State cache: SWA window + uncompressed tail buffer.

One slot per active request. A request's slot index is fixed for its
lifetime: the manager hands out slot indices at request admission
and reclaims them at completion.

Per paper §3.5.1: SWA and tail tokens are state-space-like; they
depend only on the current position, not on a paged history. No
block table is needed; each buffer is a flat [max_requests, ...] tensor.
Flush compressor: schema fix, prepare_forward, flush_write kernels, state rotation

Schema fix (paper eq.11-12): CSA needs m entries for current a-stream AND m entries for previous b-stream (tail_buffer_size_a=4, tail_buffer_size_b=4). After flush, current a-stream becomes next flush b-stream input. HCA: tail_buffer_size_a=128, tail_buffer_size_b=0 (no b-stream). tail_zb initialized to -1e9 so softmax naturally masks b-stream on first flush (paper: Z^b padded with -inf, C^b with zeros).

prepare_forward.py: Runs between captured graphs. Computes new compressed entries from position delta, pre-allocates blocks before the graph runs. Deterministic: entries_after - entries_before, ceil to block boundary. No allocation inside the captured graph.

flush_write.cu: 4 kernels:
- flush_write_csa_kernel: BF16 -> FP8 E4M3 quantize + scatter compressed entry + FP4 NVFP4 indexer key write (16-element groups, E4M3 scale). One block per request, 128 threads. Amax reduction -> inv_scale.
- flush_write_hca_kernel: same minus indexer (no FP4 write).
- csa_rotate_state_kernel: after CSA flush, rotate a->b stream, clear a-stream, reset tail_len.
- hca_reset_state_kernel: after HCA flush, clear a-stream, reset tail_len.

flush.py: Python orchestration. maybe_flush_csa/hca: always runs, kernels gate via valid_mask. Compressor produces entry, flush kernel quantize-scatters, state kernel rotates/resets. No host-side branching for cudagraph.

All tests pass on B200:
- Schema: CSA tail_a=4 tail_b=4, HCA tail_a=128 tail_b=0
- State: tail_zb initialized to -1e9, reset_slot preserves it
- prepare_forward: correct block allocation for position transitions
- HCA flush write: RoPE exact, FP8 <3.6% error, invalid mask no-op
- CSA flush write: RoPE exact, indexer FP4 keys written
- CSA state rotation: kb<-ka, zb<-za, ka/za zeroed, tail_len=0
- HCA state reset: ka/za zeroed, tail_len=0
2026-05-22 00:25:47 +00:00
CSA b-stream lifecycle (paper eq.11-12):
After a CSA flush, the current a-stream (tail_ka/tail_za) becomes
the next flush's b-stream input (tail_kb/tail_zb). Both streams are sized
at m entries, not m-1. Until the first flush, tail_zb holds -1e9,
so the softmax in the compressor masks out the b-stream:
exp(-1e9) underflows to 0, standing in for the paper's -inf padding.
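The toy pure-Python model below shows why the -1e9 fill masks the b-stream and how rotation hands the a-stream over to the next flush. It is a sketch under stated assumptions. The compressor is assumed to take, for each head dimension, a softmax over the 2m Z weights (b-stream rows, then a-stream rows) and to return the weighted sum of the matching K rows. Any projections or positional terms in the real compressor are omitted. The names `compress` and `rotate` are hypothetical.

```python
import math
from typing import List, Tuple

Rows = List[List[float]]  # [entries, head_dim], like one slot of tail_ka
NEG_MASK = -1e9           # sentinel StateCachePool writes into tail_zb


def compress(kb: Rows, zb: Rows, ka: Rows, za: Rows) -> List[float]:
    # Hypothetical compressor: per head dim, softmax over the b- and a-stream
    # Z weights, then a weighted sum of the corresponding K rows.
    k_rows, z_rows = kb + ka, zb + za
    out = []
    for d in range(len(ka[0])):
        zs = [row[d] for row in z_rows]
        z_max = max(zs)  # max-subtraction keeps exp() in range
        ws = [math.exp(z - z_max) for z in zs]  # masked rows -> exactly 0.0
        out.append(sum(w * k[d] for w, k in zip(ws, k_rows)) / sum(ws))
    return out


def zeros_like(rows: Rows) -> Rows:
    return [[0.0] * len(rows[0]) for _ in rows]


def rotate(ka: Rows, za: Rows) -> Tuple[Rows, Rows, Rows, Rows, int]:
    # Model of csa_rotate_state_kernel: kb <- ka, zb <- za, a-stream zeroed,
    # tail_len reset. Returns (ka, za, kb, zb, tail_len).
    kb = [row[:] for row in ka]
    zb = [row[:] for row in za]
    return zeros_like(ka), zeros_like(za), kb, zb, 0


# Fresh slot: b-stream K is zero and b-stream Z holds the -1e9 mask.
kb = [[0.0, 0.0], [0.0, 0.0]]
zb = [[NEG_MASK, NEG_MASK], [NEG_MASK, NEG_MASK]]
ka = [[1.0, 2.0], [3.0, 4.0]]
za = [[0.0, 0.0], [0.0, 0.0]]

first = compress(kb, zb, ka, za)
print(first)   # [2.0, 3.0]: only the a-stream contributes

ka, za, kb, zb, tail_len = rotate(ka, za)
ka = [[5.0, 6.0], [7.0, 8.0]]  # new tokens fill the a-stream
second = compress(kb, zb, ka, za)
print(second)  # [4.0, 5.0]: the previous a-stream now takes part as the b-stream
```

Because the masked weights come out as exactly zero, whatever sits in tail_kb at the first flush cannot leak into the result, provided it is finite. Under this model, reset_slot therefore only needs to restore tail_zb, not tail_kb.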
"""
from __future__ import annotations
import torch
from dsv4.cache.schema import LayerCacheSchema, AttentionType


class StateCachePool:
    """Per-layer state cache (SWA window + uncompressed tail).

    Storage layout per slot:
        swa_fp8:   [n_win, head_dim - rope_dim]  FP8    raw KV in window
        swa_rope:  [n_win, rope_dim]             BF16   RoPE'd half
        swa_inv:   [n_win]                       FP32   per-token inv scale
        swa_pos:   [n_win]                       int32  absolute position
        swa_head:  scalar                        int32  ring buffer write head
        tail_ka:   [m_a, head_dim]               BF16   current a-stream tokens
        tail_za:   [m_a, head_dim]               BF16   current a-stream Z weights
        tail_kb:   [m_b, head_dim]               BF16   previous a-stream tokens, kept as b-input (CSA only)
        tail_zb:   [m_b, head_dim]               BF16   previous a-stream Z weights, kept as b-input (CSA only, init to -1e9)
        tail_len:  scalar                        int32  number of valid entries in the a-stream
    """

    def __init__(
        self,
        schema: LayerCacheSchema,
        max_requests: int,
        device: str = "cuda",
    ):
        self.schema = schema
        self.max_requests = max_requests
        self.device = device
        mr = max_requests
        nw = schema.swa_window_size
        hd = schema.entry_head_dim
        rd = schema.rope_dim
        fp8 = hd - rd

        # SWA window: circular within each slot.
        self.swa_fp8 = torch.zeros((mr, nw, fp8), dtype=torch.uint8, device=device)
        self.swa_rope = torch.zeros((mr, nw, rd), dtype=torch.bfloat16, device=device)
        self.swa_inv = torch.ones((mr, nw), dtype=torch.float32, device=device)
        self.swa_pos = torch.full((mr, nw), -1, dtype=torch.int32, device=device)
        self.swa_head = torch.zeros((mr,), dtype=torch.int32, device=device)

        # Tail buffer: only for compressed layers.
        m_a = schema.tail_buffer_size_a  # m (CSA) or m' (HCA)
        m_b = schema.tail_buffer_size_b  # m (CSA only)
        if m_a > 0:
            self.tail_ka = torch.zeros((mr, m_a, hd), dtype=torch.bfloat16, device=device)
            self.tail_za = torch.zeros((mr, m_a, hd), dtype=torch.bfloat16, device=device)
            self.tail_len = torch.zeros((mr,), dtype=torch.int32, device=device)
            if m_b > 0:  # CSA: need b-stream
                self.tail_kb = torch.zeros((mr, m_b, hd), dtype=torch.bfloat16, device=device)
                # Paper §3.5.1: Z^b padded with -inf at first flush.
                # Init to -1e9 so softmax naturally masks b-stream on first flush.
                self.tail_zb = torch.full((mr, m_b, hd), -1e9, dtype=torch.bfloat16, device=device)
            else:
                self.tail_kb = None
                self.tail_zb = None
        else:
            self.tail_ka = self.tail_za = None
            self.tail_kb = self.tail_zb = None
            self.tail_len = None

    def reset_slot(self, slot: int) -> None:
        """Clear a request's state after completion."""
        self.swa_pos[slot].fill_(-1)
        self.swa_head[slot] = 0
        if self.tail_len is not None:
            self.tail_len[slot] = 0
        # Re-init tail_zb to -1e9 for CSA (paper §3.5.1 first-flush mask)
        if self.tail_zb is not None:
            self.tail_zb[slot].fill_(-1e9)

    def memory_bytes(self) -> int:
        """Total GPU memory used by this pool."""
        total = 0
        for name in ("swa_fp8", "swa_rope", "swa_inv", "swa_pos", "swa_head",
                     "tail_ka", "tail_za", "tail_kb", "tail_zb", "tail_len"):
            t = getattr(self, name)
            if t is not None:
                total += t.numel() * t.element_size()
        return total
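Every tensor in the pool has a leading max_requests dimension. So memory_bytes() works out to max_requests times a per-slot cost that follows directly from the shapes and dtypes in `__init__`. The helper below reproduces that arithmetic in plain Python as a sketch. Its name is hypothetical, and the window and head sizes in the usage lines (128, 512, 64) are illustrative values, not the model's actual configuration. The tail sizes (CSA 4/4, HCA 128/0) are the ones given in the flush-compressor commit.

```python
def state_cache_bytes_per_slot(nw: int, hd: int, rd: int, m_a: int, m_b: int) -> int:
    """Bytes per request slot for the tensors StateCachePool allocates."""
    fp8 = hd - rd
    per_token = (fp8 * 1   # swa_fp8: uint8
                 + rd * 2  # swa_rope: bfloat16
                 + 4       # swa_inv: float32
                 + 4)      # swa_pos: int32
    total = nw * per_token + 4  # + swa_head: int32
    if m_a > 0:
        total += 2 * m_a * hd * 2 + 4  # tail_ka + tail_za (bf16) + tail_len (int32)
        if m_b > 0:
            total += 2 * m_b * hd * 2  # tail_kb + tail_zb (bf16)
    return total


# Illustrative dims only.
csa = state_cache_bytes_per_slot(nw=128, hd=512, rd=64, m_a=4, m_b=4)
hca = state_cache_bytes_per_slot(nw=128, hd=512, rd=64, m_a=128, m_b=0)
swa = state_cache_bytes_per_slot(nw=128, hd=512, rd=64, m_a=0, m_b=0)
print(csa, hca, swa)  # 91144 336904 74756
```

Multiplying by max_requests gives the total the pool's memory_bytes() reports for the same schema, which makes the helper usable as a cross-check in tests. At these illustrative sizes, the 128-entry HCA tail costs more per slot than the whole SWA window.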