nvfp4-megamoe-kernel/dsv4/cache/allocator.py
biondizzle b4d58df620 KV Cache: schema, allocator, pools, manager, append_swa kernel
Complete KV cache substrate for DSV4 inference:

schema.py: Per-layer cache shape derived from LayerSpec.
  - CSA: 32 entries/block, 32 indexer entries, tail=3
  - HCA: 1 entry/block, no indexer, tail=127
  - SWA: no classical pool, no tail
  - BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios)
  - compute_block_budget() for allocator sizing
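
The schema numbers fit one pattern. With a per-layer compression ratio m, a 128-token block holds 128/m compressed entries, and the tail buffers up to m - 1 tokens that are not yet compressed. That implies m = 4 for CSA and m = 128 for HCA. The sketch below checks the arithmetic. The inferred ratios and the names `LayerCacheShape`, `derive_shape`, `blocks_needed` and `block_budget` are assumptions. This is not schema.py or the real compute_block_budget().

```python
import math
from dataclasses import dataclass

BLOCK_SIZE_ORIGINAL_TOKENS = 128  # from the commit message


@dataclass(frozen=True)
class LayerCacheShape:
    kind: str
    compression_ratio: int  # original tokens per compressed entry (inferred)
    entries_per_block: int
    tail: int               # max tokens buffered before the next flush


def derive_shape(kind: str, compression_ratio: int) -> LayerCacheShape:
    # A block must hold a whole number of compressed entries.
    assert BLOCK_SIZE_ORIGINAL_TOKENS % compression_ratio == 0
    return LayerCacheShape(
        kind=kind,
        compression_ratio=compression_ratio,
        entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // compression_ratio,
        tail=compression_ratio - 1,
    )


def blocks_needed(shape: LayerCacheShape, num_tokens: int) -> int:
    # Only whole groups of `compression_ratio` tokens become pool entries;
    # the remainder (fewer than `ratio` tokens) waits in the tail buffer.
    entries = num_tokens // shape.compression_ratio
    return (entries + shape.entries_per_block - 1) // shape.entries_per_block


def block_budget(shape: LayerCacheShape, max_seq_len: int, max_requests: int) -> int:
    # Hypothetical stand-in for compute_block_budget(): worst case per layer.
    return blocks_needed(shape, max_seq_len) * max_requests


csa = derive_shape("CSA", 4)    # 32 entries/block, tail 3
hca = derive_shape("HCA", 128)  # 1 entry/block, tail 127
assert math.lcm(4, 128) == BLOCK_SIZE_ORIGINAL_TOKENS
```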

allocator.py: Fixed-size block free-list.
  - GPU stack with pinned host top pointer
  - acquire/release between graph captures only
  - Raises an OOM RuntimeError on exhaustion

paged_cache.py: Per-layer classical KV storage.
  - FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper 2.3.4)
  - Per-entry inverse scale for FP8 dequant
  - FP4 indexer keys for CSA layers (NVFP4 scheme)
  - memory_bytes() tracking

state_cache.py: Per-layer SWA window + tail buffer.
  - Ring buffer with position tracking (swa_head, swa_pos)
  - CSA: dual streams (ka/za/kb/zb) for overlapping compression
  - HCA: single stream (ka/za only)
  - SWA: no tail buffer
  - reset_slot() for request completion
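
The SWA window can be modeled as a ring buffer indexed by a monotonically increasing head, with slot = head mod window. A parallel position array records which absolute token each slot holds. This minimal pure-Python sketch uses hypothetical names (`SwaRing` and its methods). The real state_cache.py keeps swa_head and swa_pos as tensors, and the kernel advances the head atomically.

```python
class SwaRing:
    """Hypothetical CPU model of one request's sliding-window ring buffer."""

    def __init__(self, window: int):
        self.window = window
        self.head = 0                 # total tokens written (never wraps)
        self.pos = [-1] * window      # absolute position per slot, -1 = empty
        self.data = [None] * window

    def write(self, token_pos: int, value) -> None:
        slot = self.head % self.window  # oldest entry is overwritten once full
        self.data[slot] = value
        self.pos[slot] = token_pos
        self.head += 1

    def read_in_order(self):
        # Valid entries are the last min(head, window) writes, oldest first.
        n = min(self.head, self.window)
        return [(self.pos[i % self.window], self.data[i % self.window])
                for i in range(self.head - n, self.head)]

    def reset(self) -> None:
        # Analogue of reset_slot() on request completion.
        self.head = 0
        self.pos = [-1] * self.window
        self.data = [None] * self.window
```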

handle.py: LayerCacheHandle — typed per-call view.
  - write_swa(), read_swa_view(), read_classical_view(), read_indexer_view()
  - No GPU allocation in acquire() — 0 bytes delta (cudagraph safe)
  - SWAView/ClassicalView/IndexerView dataclasses for kernel signatures

manager.py: KVCacheManager — owns everything.
  - Per-layer schema, pool, and allocator construction
  - admit_request()/release_request() lifecycle
  - allocate_block() for compression flush
  - acquire() returns LayerCacheHandle (zero-alloc)
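
The manager lifecycle (admit, allocate on compression flush, release) can be modeled as per-request block tables that draw from a shared free list. This is a minimal single-layer sketch. `ToyKVManager` and `on_token` are hypothetical, and so is the flush rule: one compressed entry every `ratio` tokens, and a new block whenever the previous block is full. None of this is the real KVCacheManager API.

```python
class ToyKVManager:
    def __init__(self, num_blocks: int, ratio: int, entries_per_block: int):
        self.free = list(range(num_blocks))  # shared free list (stack)
        self.ratio = ratio
        self.entries_per_block = entries_per_block
        self.block_tables: dict[int, list[int]] = {}  # request -> block IDs
        self.num_tokens: dict[int, int] = {}

    def admit_request(self, rid: int) -> None:
        if rid in self.block_tables:
            raise ValueError(f"request {rid} already admitted")
        self.block_tables[rid] = []
        self.num_tokens[rid] = 0

    def allocate_block(self, rid: int) -> int:
        if not self.free:
            raise RuntimeError("KV cache OOM")
        block = self.free.pop()
        self.block_tables[rid].append(block)
        return block

    def on_token(self, rid: int) -> None:
        # Tokens accumulate in the tail; every `ratio` tokens flush one entry.
        self.num_tokens[rid] += 1
        if self.num_tokens[rid] % self.ratio == 0:
            entry = self.num_tokens[rid] // self.ratio - 1  # flushed entry index
            if entry % self.entries_per_block == 0:  # no room in current block
                self.allocate_block(rid)

    def release_request(self, rid: int) -> None:
        # Recycle every block the request held.
        self.free.extend(self.block_tables.pop(rid))
        del self.num_tokens[rid]
```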

append_swa.cu: Native kernel for SWA writes.
  - One block per token, 128 threads per block
  - Warp-level amax reduction, BF16->FP8 E4M3 quantization
  - Atomic ring buffer head increment
  - FP8/BF16 split write + inv_scale + position metadata
  - FP8 round-trip: <3.6% relative error
  - RoPE half: exact match (no quantization)
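
The per-token quantization in append_swa can be emulated on the CPU. This sketch assumes the common amax scheme: inv_scale = amax / 448, so the largest non-RoPE element maps to the E4M3 maximum. It also assumes round-to-nearest-even, ignores NaN handling, and assumes the RoPE dims are the last `rope_dim` elements. The names `round_to_e4m3`, `quantize_token` and `dequantize` are hypothetical. The RoPE slice passes through untouched, matching the exact-match bullet. This is not the kernel's code.

```python
import math

E4M3_MAX = 448.0              # largest finite FP8 E4M3 magnitude
E4M3_MIN_NORMAL = 2.0 ** -6   # smallest normal E4M3 magnitude
E4M3_SUBNORMAL_STEP = 2.0 ** -9


def round_to_e4m3(x: float) -> float:
    """Round to the nearest FP8 E4M3 value, saturating at +/-448."""
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), E4M3_MAX)
    if a < E4M3_MIN_NORMAL:
        step = E4M3_SUBNORMAL_STEP
    else:
        _, e = math.frexp(a)           # a = m * 2**e with 0.5 <= m < 1
        step = math.ldexp(1.0, e - 4)  # 3 mantissa bits: 8 steps per binade
    # round() breaks ties to the even multiple, i.e. to an even mantissa.
    return sign * min(round(a / step) * step, E4M3_MAX)


def quantize_token(vec, rope_dim):
    """FP8-quantize the non-RoPE dims of one token; keep RoPE dims as-is."""
    split = len(vec) - rope_dim
    nope, rope = vec[:split], vec[split:]
    amax = max((abs(v) for v in nope), default=0.0)  # per-token amax
    inv_scale = amax / E4M3_MAX if amax > 0 else 1.0
    q = [round_to_e4m3(v / inv_scale) for v in nope]
    return q, inv_scale, list(rope)


def dequantize(q, inv_scale):
    return [v * inv_scale for v in q]
```

With 3 mantissa bits and round-to-nearest, the per-element relative error of normal-range values is at most half a step, or 1/16 (6.25%).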

All tests pass on B200:
  - Schema correctness for CSA/HCA/SWA
  - Allocator acquire/release/OOM
  - Pool shapes match architecture spec
  - Manager lifecycle (admit/release/recycle/exhaustion)
  - Zero-alloc acquire() (cudagraph safe)
  - append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
2026-05-22 00:08:38 +00:00

```python
"""Fixed-size block allocator for the classical paged KV cache.

One BlockAllocator per layer per "pool kind" (classical / indexer).
Total blocks are sized at engine startup. Blocks are recycled on
request completion.

Cudagraph safety: allocation can't happen inside a captured graph, which
is acceptable because the allocation rate is per-request, not per-token.
The contract is:
- acquire() is called between graph captures.
- release() is called between graph captures.
- Read access (via the block table) happens INSIDE captured graphs.
"""
from __future__ import annotations

import torch


class BlockAllocator:
    def __init__(
        self,
        num_total_blocks: int,
        device: str = "cuda",
    ):
        self.num_total_blocks = num_total_blocks
        self.device = device
        # Free-list as a GPU stack: free_ids[0..top-1] holds free block IDs.
        # `top` lives in pinned host memory so we can read it without a
        # device sync (it's modified only between graph captures).
        self.free_ids = torch.arange(
            num_total_blocks, dtype=torch.int32, device=device,
        )
        self.top_cpu = torch.tensor(
            [num_total_blocks], dtype=torch.int32, pin_memory=True,
        )

    @property
    def num_free(self) -> int:
        return int(self.top_cpu[0])

    def acquire(self, n: int) -> torch.Tensor:
        """Return a tensor of `n` block IDs. Called between captures."""
        top = int(self.top_cpu[0])
        if n > top:
            raise RuntimeError(
                f"KV cache OOM: requested {n} blocks, {top} available "
                f"(of {self.num_total_blocks} total)"
            )
        new_top = top - n
        # Pop from the top of the stack. clone() snapshots the IDs so a later
        # release() writing into these stack slots can't alias the result.
        ids = self.free_ids[new_top:top].clone()
        self.top_cpu[0] = new_top
        return ids

    def release(self, ids: torch.Tensor) -> None:
        """Return blocks to the free list. Called between captures."""
        n = ids.numel()
        top = int(self.top_cpu[0])
        # Cheap guard against over-release (e.g. a double release when the
        # pool is nearly full); it cannot catch every double release.
        if top + n > self.num_total_blocks:
            raise RuntimeError(
                f"KV cache release overflow: returning {n} blocks with "
                f"{top} already free (of {self.num_total_blocks} total)"
            )
        # Push onto the stack; the copy also casts to int32 if needed.
        self.free_ids[top:top + n] = ids.to(device=self.device)
        self.top_cpu[0] = top + n
```
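
The stack discipline of BlockAllocator can be checked without a GPU. The sketch below is a hypothetical pure-Python mirror. The name `ToyBlockAllocator` and the list-based storage are assumptions, standing in for the int32 CUDA tensor and the pinned `top_cpu` counter. It includes the same OOM and over-release checks as the class above.

```python
class ToyBlockAllocator:
    """Hypothetical CPU-only mirror of BlockAllocator's free-list stack."""

    def __init__(self, num_total_blocks: int):
        self.num_total_blocks = num_total_blocks
        self.free_ids = list(range(num_total_blocks))  # stack storage
        self.top = num_total_blocks  # free_ids[0..top-1] are free

    @property
    def num_free(self) -> int:
        return self.top

    def acquire(self, n: int) -> list[int]:
        if n > self.top:
            raise RuntimeError(
                f"KV cache OOM: requested {n} blocks, {self.top} available"
            )
        new_top = self.top - n
        ids = self.free_ids[new_top:self.top]  # list slicing copies, like clone()
        self.top = new_top
        return ids

    def release(self, ids: list[int]) -> None:
        n = len(ids)
        if self.top + n > self.num_total_blocks:
            raise RuntimeError("KV cache release overflow")
        self.free_ids[self.top:self.top + n] = ids  # push back on top
        self.top += n
```

acquire() hands out the highest free slots and release() pushes returned IDs back on top. As a result, the most recently released blocks are reused first (LIFO). For example, acquiring 2 blocks from a 4-block pool yields `[2, 3]`. After releasing them, the next acquire(2) yields `[2, 3]` again.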