nvfp4-megamoe-kernel/dsv4/cache/handle.py
biondizzle b4d58df620 KV Cache: schema, allocator, pools, manager, append_swa kernel
Complete KV cache substrate for DSV4 inference:

schema.py: Per-layer cache shape derived from LayerSpec.
  - CSA: 32 entries/block, 32 indexer entries, tail=3
  - HCA: 1 entry/block, no indexer, tail=127
  - SWA: no classical pool, no tail
  - BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios)
  - compute_block_budget() for allocator sizing
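The schema numbers above fit one simple rule. With compression ratio m, a 128-token block holds 128 / m entries, and the tail holds at most m - 1 tokens that do not yet fill an entry. The sketch below derives the per-layer shapes under that rule. The ratios used (m = 4 for CSA, m' = 128 for HCA) are inferred from the listed numbers rather than read from schema.py. `LayerCacheShape`, `layer_shape()` and the sizing rule in `block_budget()` are hypothetical stand-ins, not the real `compute_block_budget()`. It requires Python 3.9+ for `math.lcm`.

```python
from dataclasses import dataclass
from math import lcm

# ASSUMPTION: ratios inferred from the commit notes, not read from schema.py.
# CSA compresses m = 4 tokens per entry, HCA compresses m' = 128.
COMPRESSION_RATIOS = {"CSA": 4, "HCA": 128}

# One block spans the same number of original tokens in every layer type.
BLOCK_SIZE_ORIGINAL_TOKENS = lcm(*COMPRESSION_RATIOS.values())


@dataclass(frozen=True)
class LayerCacheShape:  # hypothetical stand-in for the real schema object
    entries_per_block: int
    indexer_entries_per_block: int
    tail: int


def layer_shape(kind: str) -> LayerCacheShape:
    if kind == "SWA":
        # SWA layers keep only the sliding window: no classical pool, no tail.
        return LayerCacheShape(0, 0, 0)
    m = COMPRESSION_RATIOS[kind]
    entries = BLOCK_SIZE_ORIGINAL_TOKENS // m
    # Only CSA layers carry indexer keys, one per compressed entry.
    indexer = entries if kind == "CSA" else 0
    # The tail holds at most m - 1 tokens that do not yet form a full entry.
    return LayerCacheShape(entries, indexer, m - 1)


def block_budget(max_tokens: int, max_requests: int) -> int:
    # Hypothetical sizing rule: ceil(max_tokens / block size) blocks per request.
    per_request = -(-max_tokens // BLOCK_SIZE_ORIGINAL_TOKENS)
    return per_request * max_requests
```

Choosing the block size as the lcm means a block boundary is always an entry boundary for every layer type. A single block table can therefore describe the same token span across CSA and HCA layers.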

allocator.py: Fixed-size block free-list.
  - GPU stack with pinned host top pointer
  - acquire/release between graph captures only
  - Raises OOM on exhaustion
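The allocator is a stack of free block ids plus a single top index, so acquire and release are both O(1). Below is a host-only sketch of that data structure. In allocator.py the stack lives on the GPU and the top pointer in pinned host memory, which this sketch does not model. The class name and the use of `MemoryError` for the OOM case are assumptions.

```python
class BlockFreeList:
    """Hypothetical host-only sketch of a fixed-size block free-list (a stack)."""

    def __init__(self, num_blocks: int) -> None:
        # Filled in reverse so the first acquire() returns block 0, then 1, ...
        self._stack = list(range(num_blocks - 1, -1, -1))
        self._top = num_blocks  # count of free blocks; _stack[_top - 1] is next

    @property
    def num_free(self) -> int:
        return self._top

    def acquire(self) -> int:
        if self._top == 0:
            # ASSUMPTION: exhaustion is reported as MemoryError.
            raise MemoryError("KV cache block pool exhausted")
        self._top -= 1
        return self._stack[self._top]

    def release(self, block_id: int) -> None:
        if self._top == len(self._stack):
            raise ValueError("release() called with every block already free")
        self._stack[self._top] = block_id
        self._top += 1
```

A released block goes back on top of the stack, so it is the next one handed out. This keeps recently used blocks hot.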

paged_cache.py: Per-layer classical KV storage.
  - FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper 2.3.4)
  - Per-entry inverse scale for FP8 dequant
  - FP4 indexer keys for CSA layers (NVFP4 scheme)
  - memory_bytes() tracking
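To show what the FP8/BF16 split saves, the sketch below counts bytes per compressed entry the way a memory_bytes()-style helper might. Several details are assumptions. The head and RoPE dimensions in the example are placeholders. The inverse scale is taken to be FP32. The FP4 indexer cost follows the NVFP4 layout: packed 4-bit values plus one FP8 E4M3 scale per 16 elements, with the per-tensor global scale ignored. The function names are hypothetical.

```python
# ASSUMPTIONS: placeholder dimensions; the per-entry inverse scale is FP32.
FP8_BYTES, BF16_BYTES, FP32_BYTES = 1, 2, 4
NVFP4_GROUP = 16  # NVFP4: one FP8 (E4M3) scale per 16 FP4 values


def entry_bytes(head_dim: int, rope_dim: int) -> int:
    """One classical entry: FP8 non-RoPE dims, BF16 RoPE dims, one inverse scale."""
    nope_dim = head_dim - rope_dim
    return nope_dim * FP8_BYTES + rope_dim * BF16_BYTES + FP32_BYTES


def indexer_entry_bytes(index_dim: int) -> int:
    """One FP4 indexer key: two values per byte plus one FP8 scale per group."""
    assert index_dim % NVFP4_GROUP == 0
    return index_dim // 2 + (index_dim // NVFP4_GROUP) * FP8_BYTES


def pool_bytes(num_blocks, entries_per_block, head_dim, rope_dim, index_dim=0):
    per_entry = entry_bytes(head_dim, rope_dim)
    if index_dim:  # CSA layers only; the per-tensor global scale is ignored
        per_entry += indexer_entry_bytes(index_dim)
    return num_blocks * entries_per_block * per_entry
```

With the placeholder dims head_dim = 512 and rope_dim = 64, an entry costs 580 bytes. Storing all 512 dims in BF16 would cost 1,024 bytes.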

state_cache.py: Per-layer SWA window + tail buffer.
  - Ring buffer with position tracking (swa_head, swa_pos)
  - CSA: dual streams (ka/za/kb/zb) for overlapping compression
  - HCA: single stream (ka/za only)
  - SWA: no tail buffer
  - reset_slot() for request completion
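Storing an absolute position next to each ring cell lets readers recover the window in order without shifting data. The following is a minimal host-side model of that bookkeeping for one stream per slot. It does not model the tail buffer or the CSA dual streams. Three things are assumptions: the class itself, `swa_head` as a monotonically increasing write counter, and position -1 as the marker for an empty cell.

```python
class SWARing:
    """Hypothetical host-only model of a per-slot SWA ring buffer."""

    def __init__(self, num_slots: int, window: int) -> None:
        self.window = window
        # ASSUMPTION: swa_head[s] counts all tokens written to slot s and is
        # never wrapped; the physical cell is swa_head[s] % window.
        self.swa_head = [0] * num_slots
        # swa_pos[s][i] is the absolute position stored in cell i, -1 if empty.
        self.swa_pos = [[-1] * window for _ in range(num_slots)]
        self.values = [[None] * window for _ in range(num_slots)]

    def append(self, slot: int, position: int, value) -> None:
        i = self.swa_head[slot] % self.window
        self.values[slot][i] = value
        self.swa_pos[slot][i] = position
        self.swa_head[slot] += 1

    def window_view(self, slot: int):
        """(position, value) pairs currently in the window, oldest first."""
        pairs = [
            (p, v) for p, v in zip(self.swa_pos[slot], self.values[slot]) if p >= 0
        ]
        return sorted(pairs, key=lambda pv: pv[0])

    def reset_slot(self, slot: int) -> None:
        # Called when a request completes so the slot can be reused cleanly.
        self.swa_head[slot] = 0
        self.swa_pos[slot] = [-1] * self.window
        self.values[slot] = [None] * self.window
```

After six writes into a window of four, cells 0 and 1 have been overwritten by positions 4 and 5. The position metadata is enough to present positions 2 through 5 in order.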

handle.py: LayerCacheHandle — typed per-call view.
  - write_swa(), read_swa_view(), read_classical_view(), read_indexer_view()
  - No GPU allocation in acquire() — 0 bytes delta (cudagraph safe)
  - SWAView/ClassicalView/IndexerView dataclasses for kernel signatures

manager.py: KVCacheManager — owns everything.
  - Per-layer schema, pool, and allocator construction
  - admit_request()/release_request() lifecycle
  - allocate_block() for compression flush
  - acquire() returns LayerCacheHandle (zero-alloc)
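The manager lifecycle reduces to moving slot and block ids between free lists and per-request lists. The sketch below models that bookkeeping without tensors. It also builds a block table in the shape LayerCacheHandle documents: [batch, max_logical_blocks], padded with -1, alongside per-request block counts. `ManagerSketch`, its method bodies, and the use of `MemoryError` on exhaustion are hypothetical. Only the method names `admit_request`, `allocate_block` and `release_request` come from the notes above.

```python
class ManagerSketch:
    """Hypothetical model of admit/allocate/release bookkeeping (no tensors)."""

    def __init__(self, num_slots: int, num_blocks: int) -> None:
        # Reversed so pop() hands out the lowest id first.
        self.free_slots = list(range(num_slots))[::-1]
        self.free_blocks = list(range(num_blocks))[::-1]
        self.slot_of = {}  # request id -> state cache slot
        self.blocks_of = {}  # request id -> physical block ids in logical order

    def admit_request(self, rid) -> int:
        if not self.free_slots:
            raise MemoryError("no free state-cache slot")
        self.slot_of[rid] = self.free_slots.pop()
        self.blocks_of[rid] = []
        return self.slot_of[rid]

    def allocate_block(self, rid) -> int:
        # Called when a compression flush needs a fresh block.
        if not self.free_blocks:
            raise MemoryError("no free KV block")
        block = self.free_blocks.pop()
        self.blocks_of[rid].append(block)
        return block

    def release_request(self, rid) -> None:
        # Return every block and the slot so later requests can recycle them.
        self.free_blocks.extend(reversed(self.blocks_of.pop(rid)))
        self.free_slots.append(self.slot_of.pop(rid))

    def block_table(self, rids):
        """[batch, max_blocks] table padded with -1, plus per-request lengths."""
        lens = [len(self.blocks_of[r]) for r in rids]
        width = max(lens, default=0)
        table = [self.blocks_of[r] + [-1] * (width - n) for r, n in zip(rids, lens)]
        return table, lens
```

Usage: with 2 slots and 3 blocks, request "a" takes blocks 0 and 1 and request "b" takes block 2. The table for ["a", "b"] is then [[0, 1], [2, -1]] with lengths [2, 1].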

append_swa.cu: Native kernel for SWA writes.
  - One block per token, 128 threads per block
  - Warp-level amax reduction, BF16->FP8 E4M3 quantization
  - Atomic ring buffer head increment
  - FP8/BF16 split write + inv_scale + position metadata
  - FP8 round-trip: <3.6% relative error
  - RoPE half: exact match (no quantization)
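The per-token math described for the kernel can be written as a slow host reference, which is useful for checking kernel output. It takes the amax over the non-RoPE half, scales into the E4M3 range, keeps an inverse scale, and passes the RoPE dims through. This is a sketch, not the CUDA kernel. Several details are assumptions: `inv_scale = amax / 448`, the RoPE dims sitting at the end of the vector, and all function names. BF16 rounding of the inputs is not modeled.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def round_to_e4m3(x: float) -> float:
    """Round to the nearest FP8 E4M3 value (3 mantissa bits), saturating at 448."""
    a = abs(x)
    if a == 0.0:
        return 0.0
    if a < 2.0 ** -6:
        step = 2.0 ** -9  # subnormal spacing
    else:
        _, e = math.frexp(a)  # a = f * 2**e with 0.5 <= f < 1
        step = math.ldexp(1.0, e - 4)  # 2**(floor(log2 a) - 3)
    # Python's round() breaks ties to even, matching round-to-nearest-even.
    q = min(round(a / step) * step, E4M3_MAX)
    return math.copysign(q, x)


def quantize_token(kv, rope_dim):
    """Host reference for one token: FP8 non-RoPE part, untouched RoPE part."""
    # ASSUMPTION: the RoPE dims are the last rope_dim entries of the vector.
    nope, rope = kv[: len(kv) - rope_dim], kv[len(kv) - rope_dim :]
    amax = max((abs(v) for v in nope), default=0.0)
    inv_scale = amax / E4M3_MAX if amax > 0 else 1.0
    q = [round_to_e4m3(v / inv_scale) for v in nope]
    return q, list(rope), inv_scale


def dequantize_token(q, rope, inv_scale):
    return [v * inv_scale for v in q] + rope
```

E4M3 keeps 3 mantissa bits, so in the normal range each element's relative rounding error is at most 1/16. Elements far below the token's amax can fall into the subnormal range and lose more relative precision.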

All tests pass on B200:
  - Schema correctness for CSA/HCA/SWA
  - Allocator acquire/release/OOM
  - Pool shapes match architecture spec
  - Manager lifecycle (admit/release/recycle/exhaustion)
  - Zero-alloc acquire() (cudagraph safe)
  - append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
2026-05-22 00:08:38 +00:00


"""LayerCacheHandle — typed per-call view onto one layer's cache.
Constructed by KVCacheManager.acquire() once per layer per forward.
Holds tensor references and integer indices; no allocation. Methods
expose the operations AttentionSubBlock needs without exposing the
underlying storage layout.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from dsv4.cache.paged_cache import PagedKVPool
    from dsv4.cache.state_cache import StateCachePool


@dataclass
class LayerCacheHandle:
    """Read/write interface for one layer's cache.

    The fields are the resolved indices and tensor refs for THIS call's
    batch of requests. AttentionSubBlock never sees raw pool tensors.
    """

    # Pool references (shared across handles; never mutated).
    paged: Optional["PagedKVPool"]
    state: "StateCachePool"

    # Per-call indices.
    request_slots: torch.Tensor  # [batch] int32: state cache slot per request
    positions: torch.Tensor  # [tokens] int32: absolute position per token
    request_ids: torch.Tensor  # [tokens] int32: request each token belongs to

    # Block table for the classical pool (None for SWA-only layers).
    # Shape: [batch, max_logical_blocks] int32, padded with -1 for unused entries.
    block_table: Optional[torch.Tensor]
    # Number of valid blocks per request (excludes padding).
    block_lens: Optional[torch.Tensor]

    # ------------------------------------------------------------------
    # Methods called by AttentionSubBlock
    # ------------------------------------------------------------------

    def write_swa(
        self,
        raw_kv: torch.Tensor,  # (T, head_dim) BF16
    ) -> None:
        """Write raw KV into the SWA ring buffer AND tail compression buffer.

        Both regions get the same tokens: SWA consumes the last n_win,
        the tail accumulates until it can flush.
        """
        from dsv4.kernels.cache.append_swa import append_swa_kernel

        append_swa_kernel(
            raw_kv=raw_kv,
            request_slots=self.request_slots,
            positions=self.positions,
            swa_fp8=self.state.swa_fp8,
            swa_rope=self.state.swa_rope,
            swa_inv=self.state.swa_inv,
            swa_pos=self.state.swa_pos,
            swa_head=self.state.swa_head,
            rope_dim=self.state.schema.rope_dim,
        )

    def flush_compression(
        self,
        compressed: torch.Tensor,  # (T_flush, head_dim) BF16, newly produced
        indexer_keys: Optional[torch.Tensor] = None,
    ) -> None:
        """Promote pending tail tokens into the classical pool.

        Called by the compressor when the tail buffer has enough tokens.
        Allocates a new block if the latest block is full.

        Block allocation has to happen outside the captured graph. In a
        fully captured decode it is rare (once per m or m' tokens), so it
        is kept as an explicit step; KVCacheManager owns that contract.
        """
        raise NotImplementedError("see kernels/cache/flush_compression.py")

    def read_swa_view(self) -> "SWAView":
        """Return a typed view of the SWA window for this batch."""
        return SWAView(
            fp8=self.state.swa_fp8,
            rope=self.state.swa_rope,
            inv_scale=self.state.swa_inv,
            positions=self.state.swa_pos,
            head=self.state.swa_head,
            slots=self.request_slots,
        )

    def read_classical_view(self) -> "ClassicalView":
        """Return a typed view of compressed entries for this batch."""
        assert self.paged is not None, "SWA-only layers have no classical cache"
        return ClassicalView(
            entries_fp8=self.paged.entries_fp8,
            entries_rope=self.paged.entries_rope,
            inv_scale=self.paged.inv_scale,
            block_table=self.block_table,
            block_lens=self.block_lens,
        )

    def read_indexer_view(self) -> "IndexerView":
        """CSA-only. Returns FP4 indexer keys with their scales."""
        assert self.paged is not None and self.paged.indexer_keys_fp4 is not None, (
            "only CSA layers have an indexer cache"
        )
        return IndexerView(
            keys_fp4=self.paged.indexer_keys_fp4,
            scale=self.paged.indexer_scale,
            global_scale=self.paged.indexer_global_scale,
            block_table=self.block_table,
            block_lens=self.block_lens,
        )


# Typed views: simple dataclasses with no logic. The FMHA / indexer / SWA
# kernels accept these to keep their signatures clean.


@dataclass
class SWAView:
    fp8: torch.Tensor
    rope: torch.Tensor
    inv_scale: torch.Tensor
    positions: torch.Tensor
    head: torch.Tensor
    slots: torch.Tensor


@dataclass
class ClassicalView:
    entries_fp8: torch.Tensor
    entries_rope: torch.Tensor
    inv_scale: torch.Tensor
    block_table: torch.Tensor
    block_lens: torch.Tensor


@dataclass
class IndexerView:
    keys_fp4: torch.Tensor
    scale: torch.Tensor
    global_scale: torch.Tensor
    block_table: torch.Tensor
    block_lens: torch.Tensor
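ClassicalView and IndexerView give kernels a block_table and block_lens rather than a flat per-request array. Below is a minimal sketch of the lookup a consuming kernel might perform for each compressed entry. It assumes the pool is indexed as [physical_block, entry_within_block], with entries_per_block entries per block (32 for CSA per the schema notes). `resolve_entry` is hypothetical and uses plain lists in place of tensors.

```python
def resolve_entry(block_table, block_lens, req, logical_entry, entries_per_block):
    """Map a request's logical compressed-entry index to (physical block, offset).

    block_table: [batch][max_logical_blocks] ints, padded with -1.
    block_lens:  [batch] number of valid blocks per request.
    """
    logical_block, offset = divmod(logical_entry, entries_per_block)
    if logical_entry < 0 or logical_block >= block_lens[req]:
        raise IndexError(f"entry {logical_entry} is not cached for request {req}")
    physical_block = block_table[req][logical_block]
    # Entries below block_lens never point at -1 padding.
    assert physical_block >= 0
    return physical_block, offset
```

Usage: with the table [[5, 2, -1], [7, -1, -1]], block lengths [2, 1] and 32 entries per block, entry 33 of request 0 resolves to block 2, offset 1. In a flattened pool that is index 2 * 32 + 1.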