nvfp4-megamoe-kernel/dsv4/cache/paged_cache.py
biondizzle b4d58df620 KV Cache: schema, allocator, pools, manager, append_swa kernel
Complete KV cache substrate for DSV4 inference:

schema.py: Per-layer cache shape derived from LayerSpec.
  - CSA: 32 entries/block, 32 indexer entries, tail=3
  - HCA: 1 entry/block, no indexer, tail=127
  - SWA: no classical pool, no tail
  - BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios)
  - compute_block_budget() for allocator sizing
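
The schema figures above are consistent with compression ratios of 4 (CSA) and 128 (HCA). Those ratios are not stated in this commit message and are inferred here, so the sketch below is only an illustration under that assumption. The helper names are hypothetical and are not the signatures used in schema.py.

```python
from math import lcm

# Assumed compression ratios (original tokens per compressed entry), inferred
# from "32 entries/block" and "1 entry/block"; not taken from schema.py.
CSA_RATIO = 4
HCA_RATIO = 128

# Block size in original tokens: the least common multiple of the ratios,
# so a block always holds a whole number of entries for either layer type.
BLOCK_SIZE_ORIGINAL_TOKENS = lcm(CSA_RATIO, HCA_RATIO)


def entries_per_block(ratio: int) -> int:
    """Compressed entries that fit in one block for a given ratio."""
    return BLOCK_SIZE_ORIGINAL_TOKENS // ratio


def tail_len(ratio: int) -> int:
    """Most tokens that can be pending before the next entry is emitted."""
    return ratio - 1


def blocks_for_tokens(num_tokens: int) -> int:
    """Hypothetical budget helper: blocks needed to cover num_tokens."""
    return -(-num_tokens // BLOCK_SIZE_ORIGINAL_TOKENS)  # ceiling division
```

Under these assumptions the helpers reproduce the numbers in the list: 32 entries and a tail of 3 for CSA, 1 entry and a tail of 127 for HCA.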

allocator.py: Fixed-size block free-list.
  - GPU stack with pinned host top pointer
  - acquire/release between graph captures only
  - OOM raises on exhaustion
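
A fixed-size block free-list can be sketched on the host as a plain stack of block IDs. This is not the code in allocator.py, which keeps its stack on the GPU with a pinned host top pointer; the class name, method bodies, and exception type below are assumptions chosen for illustration.

```python
class BlockFreeList:
    """Host-side stand-in for a fixed-size block free-list (hypothetical)."""

    def __init__(self, num_blocks: int) -> None:
        self.num_blocks = num_blocks
        # Stack of free physical block IDs; the end of the list is the top,
        # so block 0 is handed out first.
        self._free = list(range(num_blocks - 1, -1, -1))

    @property
    def num_free(self) -> int:
        return len(self._free)

    def acquire(self) -> int:
        """Pop one free block ID, raising when the pool is exhausted."""
        if not self._free:
            raise MemoryError("KV block pool exhausted")
        return self._free.pop()

    def release(self, block_id: int) -> None:
        """Push a block ID back onto the free stack."""
        if not 0 <= block_id < self.num_blocks:
            raise ValueError(f"invalid block id {block_id}")
        self._free.append(block_id)
```

Because acquire and release only move integers on a stack, neither touches tensor storage, which is the property that lets a real allocator stay outside graph capture.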

paged_cache.py: Per-layer classical KV storage.
  - FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper §2.3.4)
  - Per-entry inverse scale for FP8 dequant
  - FP4 indexer keys for CSA layers (NVFP4 scheme)
  - memory_bytes() tracking
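
Storing FP4 keys in uint8 means two 4-bit codes share each byte, with one scale per 16-element group. The sketch below shows one way to pack and unpack such codes. The nibble order (even index in the low nibble) is an assumption, and the function names are hypothetical; the real kernels may use a different layout.

```python
def pack_fp4(codes: list[int]) -> bytes:
    """Pack 4-bit codes two per byte (assumed order: low nibble first)."""
    if len(codes) % 2:
        raise ValueError("need an even number of FP4 codes")
    if any(not 0 <= c <= 15 for c in codes):
        raise ValueError("FP4 codes must fit in 4 bits")
    return bytes(lo | (hi << 4) for lo, hi in zip(codes[0::2], codes[1::2]))


def unpack_fp4(packed: bytes) -> list[int]:
    """Inverse of pack_fp4: recover the 4-bit codes in original order."""
    codes = []
    for byte in packed:
        codes.append(byte & 0x0F)  # element at the even index
        codes.append(byte >> 4)    # element at the odd index
    return codes


def indexer_row_shape(indexer_head_dim: int) -> tuple[int, int]:
    """(packed key bytes, group scales) stored for one indexer key."""
    return indexer_head_dim // 2, indexer_head_dim // 16
```

For a hypothetical indexer head dimension of 128, each key would take 64 packed bytes plus 8 group scales.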

state_cache.py: Per-layer SWA window + tail buffer.
  - Ring buffer with position tracking (swa_head, swa_pos)
  - CSA: dual streams (ka/za/kb/zb) for overlapping compression
  - HCA: single stream (ka/za only)
  - SWA: no tail buffer
  - reset_slot() for request completion
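
The ring buffer with position tracking can be illustrated in a few lines of host code. The field names swa_head and swa_pos come from the list above, but their exact semantics are assumed here: swa_head is treated as a monotonically increasing write counter, and swa_pos as the token position held by each slot. The class itself is hypothetical.

```python
class SwaRing:
    """Toy sliding-window ring buffer for one request slot (hypothetical)."""

    def __init__(self, window: int) -> None:
        self.window = window
        self.swa_head = 0             # assumed: total tokens written so far
        self.swa_pos = [-1] * window  # token position per slot, -1 = empty
        self.data = [None] * window

    def write(self, value) -> int:
        """Store one token, overwriting the oldest entry once full."""
        slot = self.swa_head % self.window
        self.data[slot] = value
        self.swa_pos[slot] = self.swa_head
        self.swa_head += 1
        return slot

    def reset(self) -> None:
        """Clear the slot so a new request can reuse it."""
        self.swa_head = 0
        self.swa_pos = [-1] * self.window
        self.data = [None] * self.window
```

Recording the position next to each slot lets a reader tell which tokens are currently in the window without knowing how many times the buffer has wrapped.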

handle.py: LayerCacheHandle — typed per-call view.
  - write_swa(), read_swa_view(), read_classical_view(), read_indexer_view()
  - No GPU allocation in acquire() — 0 bytes delta (cudagraph safe)
  - SWAView/ClassicalView/IndexerView dataclasses for kernel signatures

manager.py: KVCacheManager — owns everything.
  - Per-layer schema, pool, and allocator construction
  - admit_request()/release_request() lifecycle
  - allocate_block() for compression flush
  - acquire() returns LayerCacheHandle (zero-alloc)

append_swa.cu: Native kernel for SWA writes.
  - One block per token, 128 threads per block
  - Warp-level amax reduction, BF16->FP8 E4M3 quantization
  - Atomic ring buffer head increment
  - FP8/BF16 split write + inv_scale + position metadata
  - FP8 round-trip: <3.6% relative error
  - RoPE half: exact match (no quantization)
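
The amax-based quantization step can be mirrored by a scalar CPU reference. This is only a sketch: the real kernel reduces across a warp on the GPU, and the scale convention used here (map the per-entry amax to 448, the largest finite E4M3 value, and store amax / 448 as the inverse scale) is an assumption rather than something taken from append_swa.cu. All function names are hypothetical.

```python
import math

E4M3_MAX = 448.0   # largest finite float8_e4m3fn value
E4M3_MIN_EXP = -6  # exponent of the smallest normal E4M3 value
MANTISSA_BITS = 3


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even), saturating at 448."""
    if x == 0.0:
        return 0.0
    x = max(-E4M3_MAX, min(E4M3_MAX, x))
    _, e = math.frexp(abs(x))             # abs(x) = m * 2**e, 0.5 <= m < 1
    exp = max(e - 1, E4M3_MIN_EXP)        # floor(log2|x|), clamped for subnormals
    step = 2.0 ** (exp - MANTISSA_BITS)   # spacing of representable values
    return round(x / step) * step         # Python's round() is ties-to-even


def quantize_entry(values: list[float]) -> tuple[list[float], float]:
    """Return (E4M3-rounded scaled values, inverse scale) for one entry."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return [0.0] * len(values), 1.0
    scale = E4M3_MAX / amax
    return [round_to_e4m3(v * scale) for v in values], amax / E4M3_MAX


def dequantize_entry(quantized: list[float], inv_scale: float) -> list[float]:
    """Undo the scaling; the rounding error remains."""
    return [q * inv_scale for q in quantized]
```

With three mantissa bits, the relative rounding error for normal values is bounded by 1/16, so a round trip of `[1.0, -2.5, 0.3, 7.0]` recovers every element exactly except 0.3, which comes back as 0.3125.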

All tests pass on B200:
  - Schema correctness for CSA/HCA/SWA
  - Allocator acquire/release/OOM
  - Pool shapes match architecture spec
  - Manager lifecycle (admit/release/recycle/exhaustion)
  - Zero-alloc acquire() (cudagraph safe)
  - append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
2026-05-22 00:08:38 +00:00


"""Storage for one layer's classical paged KV cache.

Layout (the first dimension indexes the physical block):
    entries_fp8:  [num_blocks, entries_per_block, head_dim - rope_dim]  FP8 (uint8 view)
    entries_rope: [num_blocks, entries_per_block, rope_dim]             BF16
    inv_scale:    [num_blocks, entries_per_block]                       FP32

The FP8/BF16 split mirrors paper §2.3.4 ("BF16 for RoPE dims, FP8 for
the rest"). The kernel reads both halves and concatenates in registers.

For CSA layers, a parallel pool stores indexer keys at the same block
granularity: the same block ID maps to a block in both pools.
"""
from __future__ import annotations

from typing import Optional

import torch

from dsv4.cache.schema import LayerCacheSchema


class PagedKVPool:
    """Per-layer classical paged KV storage.

    Indexed by [physical_block_id, slot_in_block, ...].

    Both compressed entries and indexer keys (if applicable) are
    indexed by the SAME physical_block_id so a CSA layer's two pools
    share the block table.
    """

    def __init__(
        self,
        schema: LayerCacheSchema,
        num_blocks: int,
        device: str = "cuda",
    ):
        self.schema = schema
        self.num_blocks = num_blocks
        self.device = device

        nb = num_blocks
        epb = schema.entries_per_block
        hd = schema.entry_head_dim
        rd = schema.rope_dim
        fp8_dim = hd - rd  # non-RoPE dims, stored as FP8

        # ---- Compressed entries ----
        # FP8 stored as uint8 (we view as float8_e4m3fn at read time).
        self.entries_fp8 = torch.zeros(
            (nb, epb, fp8_dim), dtype=torch.uint8, device=device,
        )
        # BF16 RoPE'd half, stored without quantization.
        self.entries_rope = torch.zeros(
            (nb, epb, rd), dtype=torch.bfloat16, device=device,
        )
        # Per-entry inverse scale (for FP8 dequant in attention kernel).
        self.inv_scale = torch.ones(
            (nb, epb), dtype=torch.float32, device=device,
        )

        # ---- Indexer keys (CSA only) ----
        if schema.indexer_entries_per_block > 0:
            i_epb = schema.indexer_entries_per_block
            i_hd = schema.indexer_head_dim
            # Indexer QK is FP4 per paper §2.3.4; the keys are stored
            # post-quantization, two FP4 values packed per uint8.
            self.indexer_keys_fp4: Optional[torch.Tensor] = torch.zeros(
                (nb, i_epb, i_hd // 2), dtype=torch.uint8, device=device,
            )
            # Per-group scale for the FP4 values (one E4M3 scalar per
            # 16-element group, per the NVFP4 quantization scheme).
            self.indexer_scale: Optional[torch.Tensor] = torch.ones(
                (nb, i_epb, i_hd // 16),
                dtype=torch.float8_e4m3fn, device=device,
            )
            # One FP32 scale per physical block.
            self.indexer_global_scale: Optional[torch.Tensor] = torch.ones(
                (nb,), dtype=torch.float32, device=device,
            )
        else:
            # Non-CSA layers have no indexer pool.
            self.indexer_keys_fp4 = None
            self.indexer_scale = None
            self.indexer_global_scale = None

    def memory_bytes(self) -> int:
        """Total GPU memory used by this pool, in bytes."""
        total = 0
        for name in ("entries_fp8", "entries_rope", "inv_scale",
                     "indexer_keys_fp4", "indexer_scale", "indexer_global_scale"):
            t = getattr(self, name)
            if t is not None:
                total += t.numel() * t.element_size()
        return total
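
The footprint that memory_bytes() reports can be estimated ahead of time from the tensor shapes and element sizes in the constructor. The standalone sketch below repeats that arithmetic without torch. The function name and the example dimensions are hypothetical and are not the model's real configuration.

```python
def estimate_pool_bytes(
    num_blocks: int,
    entries_per_block: int,
    head_dim: int,
    rope_dim: int,
    indexer_entries_per_block: int = 0,
    indexer_head_dim: int = 0,
) -> int:
    """Estimate PagedKVPool memory from shapes alone (hypothetical helper)."""
    slots = num_blocks * entries_per_block
    total = slots * (head_dim - rope_dim) * 1  # entries_fp8: uint8, 1 byte
    total += slots * rope_dim * 2              # entries_rope: BF16, 2 bytes
    total += slots * 4                         # inv_scale: FP32, 4 bytes
    if indexer_entries_per_block > 0:
        i_slots = num_blocks * indexer_entries_per_block
        total += i_slots * (indexer_head_dim // 2)   # packed FP4 keys, 1 byte
        total += i_slots * (indexer_head_dim // 16)  # E4M3 group scales, 1 byte
        total += num_blocks * 4                      # FP32 scale per block
    return total


# Hypothetical dimensions, chosen only to exercise the arithmetic.
csa_like = estimate_pool_bytes(4, 32, 512, 64, 32, 128)
no_indexer = estimate_pool_bytes(4, 32, 512, 64)
```

Comparing such an estimate against memory_bytes() on a constructed pool is a cheap check that the schema and the pool agree on shapes.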