Files
nvfp4-megamoe-kernel/dsv4/_archive/cache/state_cache.py
biondizzle f3b551956d Cleanup Step 2: Archive Lineage P code, fix broken imports
- Move dead dsv4/ modules to dsv4/_archive/ (52 files)
  - model/{dsv4,mtp,layer,layer_schedule}
  - layers/{embedding,attention,ffn,norm} (kept linear,mhc,router,moe,shared_expert,grouped_linear - live)
  - cache/*, kernels/cache/*, kernels/indexer/{csa_indexer,score_topk,compute_valid_lens}
  - kernels/router/{nvfp4_fused_router,dense_router_decode_kernel,dense_router_prefill}
  - ops/{topk,topk_select,rope,router}, loader/{hf_checkpoint,layout_convert}
  - reference/{attention,compressor,csa_attention,moe_pipeline}
  - kernels/compressor/{compress_tail,csa_hca}
- Restore dsv4/ops/{router,custom_ops}.py (needed by live layers)
- Fix dsv4/kernels/{indexer,compressor,attention}/__init__.py (removed broken imports)
- Remove preload_all() from loader.py (dead, referenced nonexistent .cu file)
- Fix loader.py docstring (fused_amax_quantize_nvfp4 → quantize_nvfp4_from_buffer)
- Move broken tests to tests/e2e_archive/
  - test_fused_router, production_values_test, e2e/{one_layer,model_construction,csa_hca}
- vLLM has 0 imports of dsv4 (Step 0 confirmed)
2026-06-02 19:27:07 +00:00

103 lines
4.3 KiB
Python

"""State cache: SWA window + uncompressed tail buffer.
One slot per active request. Slot index is fixed for a request's
lifetime — the manager hands out slot indices at request admission
and reclaims them at completion.
Per paper §3.5.1: SWA and tail tokens are state-space-like — they
depend only on the current position, not on a paged history. No
block table; a flat [max_requests, ...] tensor.
CSA b-stream lifecycle (paper eq.11-12):
After a CSA flush, the current a-stream (tail_ka/tail_za) becomes
the next flush's b-stream input (tail_kb/tail_zb). Both are sized
at m entries, not m-1. tail_zb is filled with -1e9 when the pool is
allocated and again whenever a slot is reset. This value is a finite
stand-in for the paper's -inf padding: the softmax in the compressor
then masks out the b-stream on the first flush, because
exp(-1e9 - max) underflows to 0.
"""
from __future__ import annotations
import torch
from dsv4.cache.schema import LayerCacheSchema, AttentionType
class StateCachePool:
    """Per-layer state cache (SWA window + uncompressed tail).

    One slot per active request; every tensor below carries a leading
    ``max_requests`` dimension indexed by slot.

    Storage layout per slot:
        swa_fp8:  [n_win, head_dim - rope_dim] uint8 — FP8 raw KV bytes in window
        swa_rope: [n_win, rope_dim]  BF16  — RoPE'd half
        swa_inv:  [n_win]            FP32  — per-token inv scale (allocated as 1.0)
        swa_pos:  [n_win]            int32 — absolute position (-1 at alloc/reset)
        swa_head: scalar             int32 — ring buffer write head
        tail_ka:  [m_a, head_dim]    BF16  — current a-stream tokens
        tail_za:  [m_a, head_dim]    BF16  — current a-stream Z weights
        tail_kb:  [m_b, head_dim]    BF16  — previous a-stream kept as b-input (CSA only)
        tail_zb:  [m_b, head_dim]    BF16  — previous Z b-stream (CSA only, init to -1e9)
        tail_len: scalar             int32 — how many entries in a-stream are valid

    When ``schema.tail_buffer_size_a == 0`` all tail_* attributes are None.
    When ``schema.tail_buffer_size_b == 0``, tail_kb and tail_zb are None.
    """

    # Every tensor attribute owned by the pool; optional ones may be None.
    _TENSOR_FIELDS = (
        "swa_fp8", "swa_rope", "swa_inv", "swa_pos", "swa_head",
        "tail_ka", "tail_za", "tail_kb", "tail_zb", "tail_len",
    )

    # Paper §3.5.1 pads Z^b with -inf at the first flush. A large finite
    # negative value gives the same softmax masking (exp underflows to 0).
    # Shared by __init__ and reset_slot so the two can never disagree.
    _ZB_MASK_VALUE = -1e9

    def __init__(
        self,
        schema: LayerCacheSchema,
        max_requests: int,
        device: str = "cuda",
    ):
        """Allocate every per-slot buffer on ``device``.

        Args:
            schema: Layer cache schema. Reads ``swa_window_size``,
                ``entry_head_dim``, ``rope_dim``, ``tail_buffer_size_a`` and
                ``tail_buffer_size_b``.
            max_requests: Number of slots (leading dim of every tensor).
            device: Torch device on which all buffers are allocated.
        """
        self.schema = schema
        self.max_requests = max_requests
        self.device = device
        mr = max_requests
        nw = schema.swa_window_size
        hd = schema.entry_head_dim
        rd = schema.rope_dim
        fp8 = hd - rd

        # SWA window — circular within each slot. The FP8 payload is kept as
        # raw uint8 bytes.
        self.swa_fp8 = torch.zeros((mr, nw, fp8), dtype=torch.uint8, device=device)
        self.swa_rope = torch.zeros((mr, nw, rd), dtype=torch.bfloat16, device=device)
        self.swa_inv = torch.ones((mr, nw), dtype=torch.float32, device=device)
        self.swa_pos = torch.full((mr, nw), -1, dtype=torch.int32, device=device)
        self.swa_head = torch.zeros((mr,), dtype=torch.int32, device=device)

        # Tail buffer — only for compressed layers.
        m_a = schema.tail_buffer_size_a  # m (CSA) or m' (HCA)
        m_b = schema.tail_buffer_size_b  # m (CSA only)
        self.tail_ka = self.tail_za = None
        self.tail_kb = self.tail_zb = None
        self.tail_len = None
        if m_a > 0:
            self.tail_ka = torch.zeros((mr, m_a, hd), dtype=torch.bfloat16, device=device)
            self.tail_za = torch.zeros((mr, m_a, hd), dtype=torch.bfloat16, device=device)
            self.tail_len = torch.zeros((mr,), dtype=torch.int32, device=device)
            if m_b > 0:  # CSA: need b-stream
                self.tail_kb = torch.zeros(
                    (mr, m_b, hd), dtype=torch.bfloat16, device=device
                )
                # Masked until the first flush copies a real a-stream in.
                self.tail_zb = torch.full(
                    (mr, m_b, hd), self._ZB_MASK_VALUE,
                    dtype=torch.bfloat16, device=device,
                )

    def reset_slot(self, slot: int) -> None:
        """Clear a request's state after completion.

        Resets only the bookkeeping: ``swa_pos`` -> -1, ``swa_head`` -> 0,
        ``tail_len`` -> 0, and ``tail_zb`` -> the -1e9 first-flush mask.
        The payload tensors (swa_fp8 / swa_rope / swa_inv / tail_ka /
        tail_za / tail_kb) are left untouched.

        Args:
            slot: Slot index in ``[0, max_requests)``.

        Raises:
            IndexError: If ``slot`` is out of range. Negative indices are
                rejected explicitly: tensor indexing would otherwise wrap
                them and silently reset a different request's slot.
        """
        if not 0 <= slot < self.max_requests:
            raise IndexError(
                f"slot {slot} out of range for pool with {self.max_requests} slots"
            )
        self.swa_pos[slot].fill_(-1)
        self.swa_head[slot] = 0
        if self.tail_len is not None:
            self.tail_len[slot] = 0
        # Re-arm the b-stream mask for CSA (paper §3.5.1 first-flush mask).
        if self.tail_zb is not None:
            self.tail_zb[slot].fill_(self._ZB_MASK_VALUE)

    def memory_bytes(self) -> int:
        """Return the total bytes held by this pool's tensors (None fields count 0)."""
        tensors = (getattr(self, name) for name in self._TENSOR_FIELDS)
        return sum(t.numel() * t.element_size() for t in tensors if t is not None)