nvfp4-megamoe-kernel/dsv4/kernels/cache/append_swa.py
biondizzle b4d58df620 KV Cache: schema, allocator, pools, manager, append_swa kernel
Complete KV cache substrate for DSV4 inference:

schema.py: Per-layer cache shape derived from LayerSpec.
  - CSA: 32 entries/block, 32 indexer entries, tail=3
  - HCA: 1 entry/block, no indexer, tail=127
  - SWA: no classical pool, no tail
  - BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios)
  - compute_block_budget() for allocator sizing
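The schema numbers are consistent with two assumed compression ratios. CSA would compress every 4 original tokens into one entry (128/32), and HCA every 128 (128/1). The tail would then hold the ratio minus one raw tokens that have not yet filled an entry. The sketch below shows that arithmetic under those inferred ratios. `layer_block_shape` and `blocks_needed` are hypothetical helpers, not the actual `compute_block_budget()` API:

```python
import math
from dataclasses import dataclass
from typing import Optional

# ASSUMPTION: original tokens per compressed entry, inferred from
# "32 entries/block" and "1 entry/block" at 128 tokens per block.
COMPRESSION_RATIO = {"CSA": 4, "HCA": 128}

# lcm of the ratios: every block boundary is a compression boundary for both types.
BLOCK_SIZE_ORIGINAL_TOKENS = math.lcm(*COMPRESSION_RATIO.values())


@dataclass(frozen=True)
class LayerBlockShape:
    entries_per_block: int  # classical KV entries per block
    indexer_entries: int    # indexer keys per block (CSA only)
    tail: int               # raw tokens that can wait before a compression flush


def layer_block_shape(kind: str) -> Optional[LayerBlockShape]:
    """Hypothetical: derive the per-block layout for one layer type."""
    if kind == "SWA":
        return None  # SWA layers have no classical pool and no tail
    ratio = COMPRESSION_RATIO[kind]
    entries = BLOCK_SIZE_ORIGINAL_TOKENS // ratio
    indexer = entries if kind == "CSA" else 0
    return LayerBlockShape(entries, indexer, tail=ratio - 1)


def blocks_needed(context_tokens: int) -> int:
    """Hypothetical per-request budget: ceil(tokens / block size)."""
    return -(-context_tokens // BLOCK_SIZE_ORIGINAL_TOKENS)


print(layer_block_shape("CSA"))  # LayerBlockShape(entries_per_block=32, indexer_entries=32, tail=3)
print(layer_block_shape("HCA"))  # LayerBlockShape(entries_per_block=1, indexer_entries=0, tail=127)
print(blocks_needed(1000))       # 8
```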

allocator.py: Fixed-size block free-list.
  - GPU stack with pinned host top pointer
  - acquire/release between graph captures only
  - Raises OOM on exhaustion
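The free-list discipline can be sketched on the CPU: a stack of free block ids, a top pointer, and an error on exhaustion. `BlockFreeList` is a hypothetical class. The actual allocator keeps the stack on the GPU with a pinned host top pointer, and `MemoryError` here only stands in for whatever OOM exception it raises.

```python
from typing import List


class BlockFreeList:
    """CPU-only sketch of a fixed-size block free-list used as a stack."""

    def __init__(self, num_blocks: int) -> None:
        # Free block ids; stored reversed so the first acquire() returns 0.
        self._stack: List[int] = list(range(num_blocks - 1, -1, -1))
        self._top = num_blocks  # number of valid entries on the stack

    def acquire(self) -> int:
        if self._top == 0:
            raise MemoryError("KV block pool exhausted")  # stand-in for OOM
        self._top -= 1
        return self._stack[self._top]

    def release(self, block_id: int) -> None:
        if self._top == len(self._stack):
            raise ValueError("release() without a matching acquire()")
        self._stack[self._top] = block_id
        self._top += 1

    @property
    def num_free(self) -> int:
        return self._top


pool = BlockFreeList(2)
a = pool.acquire()  # 0
b = pool.acquire()  # 1
pool.release(a)
c = pool.acquire()  # 0 again: the most recently freed block is reused first
```

A stack gives O(1) acquire and release with no searching. Because the block count is fixed, the backing storage never grows.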

paged_cache.py: Per-layer classical KV storage.
  - FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper 2.3.4)
  - Per-entry inverse scale for FP8 dequant
  - FP4 indexer keys for CSA layers (NVFP4 scheme)
  - memory_bytes() tracking
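A rough per-entry byte count shows what the FP8/BF16 split saves. This sketch is illustrative only. The dims (head_dim 512, rope_dim 64, indexer dim 128) are hypothetical rather than taken from the DSV4 config. `layer_memory_bytes` is not the real `memory_bytes()`, and it ignores alignment padding. The FP4 key size assumes the NVFP4 layout of packed 4-bit values with one FP8 scale per 16-element group, ignoring the per-tensor scale.

```python
def classical_entry_bytes(head_dim: int, rope_dim: int) -> int:
    """Bytes for one cached entry under the FP8/BF16 split."""
    fp8_dim = head_dim - rope_dim
    # 1 byte per FP8 element, 2 per BF16 RoPE element, 4 for the FP32 inv_scale
    return fp8_dim * 1 + rope_dim * 2 + 4


def nvfp4_key_bytes(dim: int, group: int = 16) -> int:
    """Bytes for one FP4 indexer key: two nibbles per byte + one FP8 scale per group."""
    if dim % group != 0:
        raise ValueError("dim must be a multiple of the NVFP4 group size")
    return dim // 2 + dim // group


def layer_memory_bytes(num_blocks: int, entries_per_block: int,
                       head_dim: int, rope_dim: int,
                       indexer_entries: int = 0, indexer_dim: int = 0) -> int:
    """Hypothetical per-layer total (no alignment, no per-tensor scale)."""
    per_block = entries_per_block * classical_entry_bytes(head_dim, rope_dim)
    if indexer_entries:
        per_block += indexer_entries * nvfp4_key_bytes(indexer_dim)
    return num_blocks * per_block


# Illustrative dims only.
print(classical_entry_bytes(512, 64))  # 580 (vs 1024 bytes if all 512 dims were BF16)
print(layer_memory_bytes(1000, 32, 512, 64, indexer_entries=32, indexer_dim=128))  # 20864000
```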

state_cache.py: Per-layer SWA window + tail buffer.
  - Ring buffer with position tracking (swa_head, swa_pos)
  - CSA: dual streams (ka/za/kb/zb) for overlapping compression
  - HCA: single stream (ka/za only)
  - SWA: no tail buffer
  - reset_slot() for request completion
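The ring buffer's position tracking can be modeled on the CPU as shown below. This is a sketch under stated assumptions. `SWAStateSketch` is a hypothetical class that tracks only positions, not the FP8/RoPE payload or the CSA/HCA tail streams. It also assumes `swa_head` is a monotonically increasing write count, and that its value modulo the window size gives the next ring index.

```python
from typing import List


class SWAStateSketch:
    """Hypothetical CPU model of the per-request SWA ring (swa_head, swa_pos).

    ASSUMPTION: swa_head counts tokens written and the ring index is
    swa_head % n_win; the CUDA kernel bumps the head atomically instead.
    """

    def __init__(self, max_req: int, n_win: int) -> None:
        self.n_win = n_win
        self.swa_head = [0] * max_req
        # -1 marks an empty ring entry
        self.swa_pos = [[-1] * n_win for _ in range(max_req)]

    def append(self, req: int, position: int) -> int:
        """Record one token's position; returns the ring index written."""
        idx = self.swa_head[req] % self.n_win
        self.swa_pos[req][idx] = position
        self.swa_head[req] += 1
        return idx

    def window_positions(self, req: int) -> List[int]:
        """Positions currently in the window, oldest first."""
        head = self.swa_head[req]
        n = min(head, self.n_win)
        return [self.swa_pos[req][(head - n + i) % self.n_win] for i in range(n)]

    def reset_slot(self, req: int) -> None:
        """Clear one request slot on completion so it can be reused."""
        self.swa_head[req] = 0
        self.swa_pos[req] = [-1] * self.n_win


s = SWAStateSketch(max_req=2, n_win=4)
for p in range(6):
    s.append(0, p)
s.append(1, 100)
print(s.window_positions(0))  # [2, 3, 4, 5]: positions 0 and 1 were overwritten
print(s.window_positions(1))  # [100]
```

Keeping one head per request slot is what isolates requests from each other. Resetting slot 0 leaves slot 1 untouched.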

handle.py: LayerCacheHandle — typed per-call view.
  - write_swa(), read_swa_view(), read_classical_view(), read_indexer_view()
  - No GPU allocation in acquire() — 0 bytes delta (cudagraph safe)
  - SWAView/ClassicalView/IndexerView dataclasses for kernel signatures

manager.py: KVCacheManager — owns everything.
  - Per-layer schema, pool, and allocator construction
  - admit_request()/release_request() lifecycle
  - allocate_block() for compression flush
  - acquire() returns LayerCacheHandle (zero-alloc)

append_swa.cu: Native kernel for SWA writes.
  - One block per token, 128 threads per block
  - Warp-level amax reduction, BF16->FP8 E4M3 quantization
  - Atomic ring buffer head increment
  - FP8/BF16 split write + inv_scale + position metadata
  - FP8 round-trip: <3.6% relative error
  - RoPE half: exact match (no quantization)
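A pure-Python reference for the per-token quantization step can help when checking the kernel's output. This sketch makes several assumptions. It treats the RoPE dims as the trailing `rope_dim` elements of each row. It emulates E4M3 in its FN variant (max finite 448) with round-to-nearest-even and saturation. `round_e4m3`, `quantize_token` and `dequantize` are hypothetical names, and the kernel's exact rounding and clamping may differ.

```python
import math
from typing import List, Tuple

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def round_e4m3(x: float) -> float:
    """Round to the nearest E4M3 (FN) value, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), E4M3_MAX)
    _, e = math.frexp(a)        # a lies in [2**(e-1), 2**e)
    exp = max(e - 1, -6)        # below 2**-6 the fixed subnormal spacing applies
    step = 2.0 ** (exp - 3)     # 3 mantissa bits -> 8 steps per binade
    q = round(a / step) * step  # Python round() is round-half-to-even
    return sign * min(q, E4M3_MAX)


def quantize_token(row: List[float], rope_dim: int) -> Tuple[List[float], List[float], float]:
    """Split one token's KV row into (FP8 values, RoPE values, inv_scale).

    ASSUMPTION: the RoPE dims are the trailing rope_dim elements.
    """
    fp8_part = row[: len(row) - rope_dim]
    rope_part = row[len(row) - rope_dim:]
    amax = max(abs(v) for v in fp8_part)  # the kernel does this as a warp reduce
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    q = [round_e4m3(v * scale) for v in fp8_part]
    return q, list(rope_part), 1.0 / scale  # RoPE half is copied unquantized


def dequantize(q: List[float], inv_scale: float) -> List[float]:
    return [v * inv_scale for v in q]


row = [0.5, -3.0, 1.3, 0.02, 7.0, -0.75, 0.123, -4.5]
q, rope, inv = quantize_token(row, rope_dim=2)
deq = dequantize(q, inv)
```

For values that stay in the normal range after scaling, rounding to 3 mantissa bits bounds each element's relative error by 2^-4 (6.25%). Mapping amax to 448 keeps the largest element exactly representable.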

All tests pass on B200:
  - Schema correctness for CSA/HCA/SWA
  - Allocator acquire/release/OOM
  - Pool shapes match architecture spec
  - Manager lifecycle (admit/release/recycle/exhaustion)
  - Zero-alloc acquire() (cudagraph safe)
  - append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
2026-05-22 00:08:38 +00:00


"""Python wrapper for the append_swa CUDA kernel.

Writes raw BF16 KV into the FP8/BF16 split state cache layout.
Quantizes the non-RoPE half BF16 -> FP8 (E4M3 amax-based scaling),
writes the RoPE half as-is, computes a per-token inverse scale, and
updates the ring buffer head and position field.

One block per token. Threads cooperatively:
1. Compute amax over the fp8-dim elements (warp reduce).
2. Quantize BF16 -> FP8 with the per-token scale.
3. Write FP8 entries + BF16 RoPE entries + inv_scale + position.
4. Atomically increment the ring buffer head.
"""

import os

import torch
from torch.utils.cpp_extension import load

# JIT-compiled extension, built on first use and cached for the process.
_kernel_module = None


def _get_kernel_module():
    """Build (once) and return the compiled append_swa extension."""
    global _kernel_module
    if _kernel_module is not None:
        return _kernel_module
    # append_swa.cu lives in the sibling cuda/ directory.
    kernel_dir = os.path.join(os.path.dirname(__file__), "..", "cuda")
    _kernel_module = load(
        name="append_swa",
        sources=[os.path.join(kernel_dir, "append_swa.cu")],
        # Targets sm_100a (Blackwell, e.g. B200) only.
        extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
        verbose=False,
    )
    return _kernel_module


def append_swa_kernel(
    raw_kv: torch.Tensor,         # (T, head_dim) BF16
    request_slots: torch.Tensor,  # (T,) int32
    positions: torch.Tensor,      # (T,) int32
    swa_fp8: torch.Tensor,        # (max_req, n_win, fp8_dim) uint8
    swa_rope: torch.Tensor,       # (max_req, n_win, rope_dim) BF16
    swa_inv: torch.Tensor,        # (max_req, n_win) FP32
    swa_pos: torch.Tensor,        # (max_req, n_win) int32
    swa_head: torch.Tensor,       # (max_req,) int32
    rope_dim: int,
) -> None:
    """Append T tokens of raw BF16 KV to the SWA state cache.

    The swa_* tensors are updated in place by the kernel.
    """
    mod = _get_kernel_module()
    mod.append_swa(
        raw_kv, request_slots, positions,
        swa_fp8, swa_rope, swa_inv, swa_pos, swa_head,
        rope_dim,
    )