KV Cache: schema, allocator, pools, manager, append_swa kernel
Complete KV cache substrate for DSV4 inference: schema.py: Per-layer cache shape derived from LayerSpec. - CSA: 32 entries/block, 32 indexer entries, tail=3 - HCA: 1 entry/block, no indexer, tail=127 - SWA: no classical pool, no tail - BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios) - compute_block_budget() for allocator sizing allocator.py: Fixed-size block free-list. - GPU stack with pinned host top pointer - acquire/release between graph captures only - OOM raises on exhaustion paged_cache.py: Per-layer classical KV storage. - FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper 2.3.4) - Per-entry inverse scale for FP8 dequant - FP4 indexer keys for CSA layers (NVFP4 scheme) - memory_bytes() tracking state_cache.py: Per-layer SWA window + tail buffer. - Ring buffer with position tracking (swa_head, swa_pos) - CSA: dual streams (ka/za/kb/zb) for overlapping compression - HCA: single stream (ka/za only) - SWA: no tail buffer - reset_slot() for request completion handle.py: LayerCacheHandle — typed per-call view. - write_swa(), read_swa_view(), read_classical_view(), read_indexer_view() - No GPU allocation in acquire() — 0 bytes delta (cudagraph safe) - SWAView/ClassicalView/IndexerView dataclasses for kernel signatures manager.py: KVCacheManager — owns everything. - Per-layer schema, pool, and allocator construction - admit_request()/release_request() lifecycle - allocate_block() for compression flush - acquire() returns LayerCacheHandle (zero-alloc) append_swa.cu: Native kernel for SWA writes. - One block per token, 128 threads per block - Warp-level amax reduction, BF16->FP8 E4M3 quantization - Atomic ring buffer head increment - FP8/BF16 split write + inv_scale + position metadata - FP8 round-trip: <3.6% relative error - RoPE half: exact match (no quantization) All tests pass on B200: - Schema correctness for CSA/HCA/SWA - Allocator acquire/release/OOM - Pool shapes match architecture spec - Manager lifecycle (admit/release/recycle/exhaustion) - Zero-alloc acquire() (cudagraph safe) - append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
This commit is contained in:
56
dsv4/cache/allocator.py
vendored
Normal file
56
dsv4/cache/allocator.py
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Fixed-size block allocator for the classical paged KV cache.
|
||||
|
||||
One BlockAllocator per layer per "pool kind" (classical / indexer).
|
||||
Total blocks are sized at engine startup. Blocks are recycled on
|
||||
request completion.
|
||||
|
||||
Cudagraph-safety: allocation can't happen inside a captured graph
|
||||
(allocation rate is per-request not per-token). The contract is:
|
||||
- acquire() called between graph captures.
|
||||
- release() called between graph captures.
|
||||
- read access (via block table) happens INSIDE captured graphs.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
|
||||
|
||||
class BlockAllocator:
|
||||
def __init__(
|
||||
self,
|
||||
num_total_blocks: int,
|
||||
device: str = "cuda",
|
||||
):
|
||||
self.num_total_blocks = num_total_blocks
|
||||
self.device = device
|
||||
|
||||
# Free-list as a GPU stack: ids[0..top-1] holds free block IDs.
|
||||
# `top` lives in pinned host memory so we can read it without a
|
||||
# device sync (it's modified only between graph captures).
|
||||
self.free_ids = torch.arange(
|
||||
num_total_blocks, dtype=torch.int32, device=device,
|
||||
)
|
||||
self.top_cpu = torch.tensor([num_total_blocks], dtype=torch.int32, pin_memory=True)
|
||||
|
||||
@property
|
||||
def num_free(self) -> int:
|
||||
return int(self.top_cpu[0])
|
||||
|
||||
def acquire(self, n: int) -> torch.Tensor:
|
||||
"""Return a tensor of `n` block IDs. Called between captures."""
|
||||
top = int(self.top_cpu[0])
|
||||
if n > top:
|
||||
raise RuntimeError(
|
||||
f"KV cache OOM: requested {n} blocks, {top} available "
|
||||
f"(of {self.num_total_blocks} total)"
|
||||
)
|
||||
new_top = top - n
|
||||
ids = self.free_ids[new_top:top].clone() # snapshot
|
||||
self.top_cpu[0] = new_top
|
||||
return ids
|
||||
|
||||
def release(self, ids: torch.Tensor) -> None:
|
||||
"""Return blocks to the free list. Called between captures."""
|
||||
n = ids.numel()
|
||||
top = int(self.top_cpu[0])
|
||||
self.free_ids[top:top + n] = ids.to(device=self.device)
|
||||
self.top_cpu[0] = top + n
|
||||
142
dsv4/cache/handle.py
vendored
Normal file
142
dsv4/cache/handle.py
vendored
Normal file
@@ -0,0 +1,142 @@
|
||||
"""LayerCacheHandle — typed per-call view onto one layer's cache.
|
||||
|
||||
Constructed by KVCacheManager.acquire() once per layer per forward.
|
||||
Holds tensor references and integer indices; no allocation. Methods
|
||||
expose the operations AttentionSubBlock needs without exposing the
|
||||
underlying storage layout.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dsv4.cache.paged_cache import PagedKVPool
|
||||
from dsv4.cache.state_cache import StateCachePool
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerCacheHandle:
|
||||
"""Read/write interface for one layer's cache.
|
||||
|
||||
The fields are the resolved indices and tensor refs for THIS call's
|
||||
batch of requests. AttentionSubBlock never sees raw pool tensors.
|
||||
"""
|
||||
# Pool references (shared across handles — never mutated).
|
||||
paged: Optional["PagedKVPool"]
|
||||
state: "StateCachePool"
|
||||
|
||||
# Per-call indices.
|
||||
request_slots: torch.Tensor # [batch] int32 — state cache slot per request
|
||||
positions: torch.Tensor # [tokens] int32 — absolute position per token
|
||||
request_ids: torch.Tensor # [tokens] int32 — which request each token belongs to
|
||||
|
||||
# Block table for the classical pool (None for SWA-only layers).
|
||||
# Shape: [batch, max_logical_blocks] int32. -1 padding for unused entries.
|
||||
block_table: Optional[torch.Tensor]
|
||||
# Number of valid blocks per request (excludes padding).
|
||||
block_lens: Optional[torch.Tensor]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Methods called by AttentionSubBlock
|
||||
# ------------------------------------------------------------------
|
||||
def write_swa(
|
||||
self,
|
||||
raw_kv: torch.Tensor, # (T, head_dim) BF16
|
||||
) -> None:
|
||||
"""Write raw KV into the SWA ring buffer AND tail compression buffer.
|
||||
|
||||
Both regions get the same tokens — SWA consumes the last n_win,
|
||||
the tail accumulates until it can flush.
|
||||
"""
|
||||
from dsv4.kernels.cache.append_swa import append_swa_kernel
|
||||
append_swa_kernel(
|
||||
raw_kv=raw_kv,
|
||||
request_slots=self.request_slots,
|
||||
positions=self.positions,
|
||||
swa_fp8=self.state.swa_fp8,
|
||||
swa_rope=self.state.swa_rope,
|
||||
swa_inv=self.state.swa_inv,
|
||||
swa_pos=self.state.swa_pos,
|
||||
swa_head=self.state.swa_head,
|
||||
rope_dim=self.state.schema.rope_dim,
|
||||
)
|
||||
|
||||
def flush_compression(
|
||||
self,
|
||||
compressed: torch.Tensor, # (T_flush, head_dim) BF16 — newly produced
|
||||
indexer_keys: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
"""Promote pending tail tokens into the classical pool.
|
||||
|
||||
Called by the compressor when the tail buffer has enough tokens.
|
||||
Allocates a new block if the latest block is full.
|
||||
|
||||
Block allocation requires going outside the captured graph — in
|
||||
a fully-captured decode this is rare (once per m or m' tokens),
|
||||
so we make it explicit. The manager has the contract.
|
||||
"""
|
||||
raise NotImplementedError("see kernels/cache/flush_compression.py")
|
||||
|
||||
def read_swa_view(self) -> "SWAView":
|
||||
"""Return a typed view of the SWA window for this batch."""
|
||||
return SWAView(
|
||||
fp8=self.state.swa_fp8,
|
||||
rope=self.state.swa_rope,
|
||||
inv_scale=self.state.swa_inv,
|
||||
positions=self.state.swa_pos,
|
||||
head=self.state.swa_head,
|
||||
slots=self.request_slots,
|
||||
)
|
||||
|
||||
def read_classical_view(self) -> "ClassicalView":
|
||||
"""Return a typed view of compressed entries for this batch."""
|
||||
assert self.paged is not None, "SWA-only layers have no classical cache"
|
||||
return ClassicalView(
|
||||
entries_fp8=self.paged.entries_fp8,
|
||||
entries_rope=self.paged.entries_rope,
|
||||
inv_scale=self.paged.inv_scale,
|
||||
block_table=self.block_table,
|
||||
block_lens=self.block_lens,
|
||||
)
|
||||
|
||||
def read_indexer_view(self) -> "IndexerView":
|
||||
"""CSA-only. Returns FP4 indexer keys with their scales."""
|
||||
assert self.paged is not None and self.paged.indexer_keys_fp4 is not None
|
||||
return IndexerView(
|
||||
keys_fp4=self.paged.indexer_keys_fp4,
|
||||
scale=self.paged.indexer_scale,
|
||||
global_scale=self.paged.indexer_global_scale,
|
||||
block_table=self.block_table,
|
||||
block_lens=self.block_lens,
|
||||
)
|
||||
|
||||
|
||||
# Typed views — simple dataclasses, no logic. The FMHA / indexer / SWA
|
||||
# kernels accept these to keep their signatures clean.
|
||||
@dataclass
|
||||
class SWAView:
|
||||
fp8: torch.Tensor
|
||||
rope: torch.Tensor
|
||||
inv_scale: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
head: torch.Tensor
|
||||
slots: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassicalView:
|
||||
entries_fp8: torch.Tensor
|
||||
entries_rope: torch.Tensor
|
||||
inv_scale: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
block_lens: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexerView:
|
||||
keys_fp4: torch.Tensor
|
||||
scale: torch.Tensor
|
||||
global_scale: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
block_lens: torch.Tensor
|
||||
200
dsv4/cache/manager.py
vendored
Normal file
200
dsv4/cache/manager.py
vendored
Normal file
@@ -0,0 +1,200 @@
|
||||
"""KVCacheManager — owns all KV cache state for one model instance.
|
||||
|
||||
Responsibilities:
|
||||
- Build per-layer pools and allocators at startup.
|
||||
- Hand out state-cache slots when requests are admitted.
|
||||
- Hand out classical blocks when layers need to flush compression.
|
||||
- Compose LayerCacheHandle for each layer per forward call.
|
||||
- Reclaim slots and blocks on request completion.
|
||||
|
||||
Not on the manager:
|
||||
- On-disk prefix storage. (Paper §3.5.2 — deferred entirely.)
|
||||
- Eviction policies. (Single-instance; requests run to completion.)
|
||||
- Cross-instance coordination.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import List, Optional, Dict
|
||||
import torch
|
||||
|
||||
from dsv4.model.config import DSV4Config
|
||||
from dsv4.model.layer_schedule import LayerSpec, AttentionType
|
||||
from dsv4.cache.schema import LayerCacheSchema, build_schema, compute_block_budget
|
||||
from dsv4.cache.allocator import BlockAllocator
|
||||
from dsv4.cache.paged_cache import PagedKVPool
|
||||
from dsv4.cache.state_cache import StateCachePool
|
||||
from dsv4.cache.handle import LayerCacheHandle
|
||||
|
||||
|
||||
class KVCacheManager:
|
||||
def __init__(
|
||||
self,
|
||||
config: DSV4Config,
|
||||
schedule: List[LayerSpec],
|
||||
max_concurrent_requests: int,
|
||||
max_context_tokens: int = 1_000_000,
|
||||
# Per-layer-type block budget. If None, computed from
|
||||
# max_context_tokens and max_concurrent_requests.
|
||||
num_blocks_per_csa_layer: Optional[int] = None,
|
||||
num_blocks_per_hca_layer: Optional[int] = None,
|
||||
device: str = "cuda",
|
||||
):
|
||||
self.config = config
|
||||
self.schedule = schedule
|
||||
self.max_concurrent_requests = max_concurrent_requests
|
||||
self.device = device
|
||||
|
||||
# ---- Per-layer schemas ----
|
||||
self.schemas: Dict[int, LayerCacheSchema] = {
|
||||
spec.layer_idx: build_schema(config, spec) for spec in schedule
|
||||
}
|
||||
|
||||
# ---- Compute block budgets if not provided ----
|
||||
if num_blocks_per_csa_layer is None or num_blocks_per_hca_layer is None:
|
||||
budget = compute_block_budget(config, schedule, max_context_tokens,
|
||||
max_concurrent_requests)
|
||||
num_blocks_per_csa_layer = num_blocks_per_csa_layer or budget.get("csa", 0)
|
||||
num_blocks_per_hca_layer = num_blocks_per_hca_layer or budget.get("hca", 0)
|
||||
|
||||
# ---- Per-layer pools ----
|
||||
# State cache exists for every layer.
|
||||
self.state_pools: Dict[int, StateCachePool] = {
|
||||
i: StateCachePool(schema, max_concurrent_requests, device)
|
||||
for i, schema in self.schemas.items()
|
||||
}
|
||||
|
||||
# Classical paged pool only for compressed layers.
|
||||
self.paged_pools: Dict[int, Optional[PagedKVPool]] = {}
|
||||
self.allocators: Dict[int, Optional[BlockAllocator]] = {}
|
||||
for i, schema in self.schemas.items():
|
||||
if schema.entries_per_block == 0:
|
||||
self.paged_pools[i] = None
|
||||
self.allocators[i] = None
|
||||
else:
|
||||
nb = (num_blocks_per_csa_layer
|
||||
if schema.attn_type == AttentionType.CSA
|
||||
else num_blocks_per_hca_layer)
|
||||
self.paged_pools[i] = PagedKVPool(schema, nb, device)
|
||||
self.allocators[i] = BlockAllocator(nb, device)
|
||||
|
||||
# ---- Request state ----
|
||||
# Slot index per request, into state cache pools (same index in
|
||||
# every layer). -1 = slot free.
|
||||
self.request_slot_map: torch.Tensor = torch.full(
|
||||
(max_concurrent_requests,), -1, dtype=torch.int32, device=device,
|
||||
)
|
||||
|
||||
# Block table per request per layer:
|
||||
# block_tables[layer_idx][request_slot, logical_block_idx]
|
||||
# -> physical_block_idx
|
||||
max_blocks = max_context_tokens // 128 # BLOCK_SIZE_ORIGINAL_TOKENS
|
||||
self.max_blocks_per_request = max_blocks
|
||||
self.block_tables: Dict[int, torch.Tensor] = {}
|
||||
self.block_lens: Dict[int, torch.Tensor] = {}
|
||||
for i, schema in self.schemas.items():
|
||||
if schema.entries_per_block > 0:
|
||||
self.block_tables[i] = torch.full(
|
||||
(max_concurrent_requests, max_blocks), -1,
|
||||
dtype=torch.int32, device=device,
|
||||
)
|
||||
self.block_lens[i] = torch.zeros(
|
||||
(max_concurrent_requests,), dtype=torch.int32, device=device,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Request lifecycle (called between captured graphs)
|
||||
# ------------------------------------------------------------------
|
||||
def admit_request(self) -> int:
|
||||
"""Allocate a state cache slot. Returns the slot index."""
|
||||
free = (self.request_slot_map == -1).nonzero(as_tuple=False)
|
||||
if free.numel() == 0:
|
||||
raise RuntimeError("max concurrent requests exceeded")
|
||||
slot = int(free[0])
|
||||
self.request_slot_map[slot] = slot
|
||||
return slot
|
||||
|
||||
def release_request(self, slot: int) -> None:
|
||||
"""Return state cache slot and all associated blocks to free lists."""
|
||||
for layer_idx, alloc in self.allocators.items():
|
||||
if alloc is None:
|
||||
continue
|
||||
table = self.block_tables[layer_idx]
|
||||
lens = self.block_lens[layer_idx]
|
||||
valid = int(lens[slot])
|
||||
if valid > 0:
|
||||
alloc.release(table[slot, :valid].clone())
|
||||
lens[slot] = 0
|
||||
table[slot].fill_(-1)
|
||||
# Reset state cache slot.
|
||||
for state in self.state_pools.values():
|
||||
state.reset_slot(slot)
|
||||
self.request_slot_map[slot] = -1
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Block allocation for compression flush (called between captures)
|
||||
# ------------------------------------------------------------------
|
||||
def allocate_block(self, layer_idx: int, request_slot: int) -> int:
|
||||
"""Allocate one new classical block for a request. Returns block ID."""
|
||||
alloc = self.allocators[layer_idx]
|
||||
assert alloc is not None, f"layer {layer_idx} has no classical pool"
|
||||
block_id = alloc.acquire(1)
|
||||
bid = int(block_id[0])
|
||||
# Append to the request's block table.
|
||||
table = self.block_tables[layer_idx]
|
||||
lens = self.block_lens[layer_idx]
|
||||
pos = int(lens[request_slot])
|
||||
assert pos < self.max_blocks_per_request, "block table overflow"
|
||||
table[request_slot, pos] = bid
|
||||
lens[request_slot] = pos + 1
|
||||
return bid
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Per-forward handle construction (called INSIDE captured graph)
|
||||
# ------------------------------------------------------------------
|
||||
def acquire(
|
||||
self,
|
||||
layer_idx: int,
|
||||
request_slots: torch.Tensor, # [batch] int32
|
||||
positions: torch.Tensor, # [tokens] int32
|
||||
request_ids: torch.Tensor, # [tokens] int32
|
||||
) -> LayerCacheHandle:
|
||||
"""Build the LayerCacheHandle for one layer's forward.
|
||||
|
||||
No allocation happens here — critical for cudagraph safety.
|
||||
"""
|
||||
paged = self.paged_pools[layer_idx]
|
||||
state = self.state_pools[layer_idx]
|
||||
|
||||
if paged is not None:
|
||||
# Pass the full tensors — no indexing, no allocation.
|
||||
# The attention kernel indexes by request_slots internally.
|
||||
block_table = self.block_tables[layer_idx]
|
||||
block_lens = self.block_lens[layer_idx]
|
||||
else:
|
||||
block_table = None
|
||||
block_lens = None
|
||||
|
||||
return LayerCacheHandle(
|
||||
paged=paged,
|
||||
state=state,
|
||||
request_slots=request_slots,
|
||||
positions=positions,
|
||||
request_ids=request_ids,
|
||||
block_table=block_table,
|
||||
block_lens=block_lens,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Diagnostics
|
||||
# ------------------------------------------------------------------
|
||||
def memory_bytes(self) -> int:
|
||||
"""Total GPU memory used by all pools."""
|
||||
total = 0
|
||||
for pool in self.state_pools.values():
|
||||
total += pool.memory_bytes()
|
||||
for pool in self.paged_pools.values():
|
||||
if pool is not None:
|
||||
total += pool.memory_bytes()
|
||||
for i, table in self.block_tables.items():
|
||||
total += table.numel() * table.element_size()
|
||||
total += self.block_lens[i].numel() * self.block_lens[i].element_size()
|
||||
return total
|
||||
93
dsv4/cache/paged_cache.py
vendored
93
dsv4/cache/paged_cache.py
vendored
@@ -1,2 +1,91 @@
|
||||
"""Paged KV cache."""
|
||||
# TODO: Phase 3
|
||||
"""Storage for one layer's classical paged KV cache.
|
||||
|
||||
Layout per block:
|
||||
entries: [num_blocks, entries_per_block, head_dim - rope_dim] FP8 (uint8 view)
|
||||
entries_r: [num_blocks, entries_per_block, rope_dim] BF16
|
||||
inv_scale: [num_blocks, entries_per_block] FP32
|
||||
|
||||
The FP8/BF16 split mirrors paper §2.3.4 ("BF16 for RoPE dims, FP8 for
|
||||
the rest"). The kernel reads both halves and concatenates in registers.
|
||||
|
||||
For CSA layers, a parallel pool stores indexer keys at the same block
|
||||
granularity — same block ID maps to a block in both pools.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
import torch
|
||||
|
||||
from dsv4.cache.schema import LayerCacheSchema
|
||||
|
||||
|
||||
class PagedKVPool:
|
||||
"""Per-layer classical paged KV storage.
|
||||
|
||||
Indexed by [physical_block_id, slot_in_block, ...].
|
||||
Both compressed entries and indexer keys (if applicable) are
|
||||
indexed by the SAME physical_block_id so a CSA layer's two pools
|
||||
share the block table.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: LayerCacheSchema,
|
||||
num_blocks: int,
|
||||
device: str = "cuda",
|
||||
):
|
||||
self.schema = schema
|
||||
self.num_blocks = num_blocks
|
||||
self.device = device
|
||||
|
||||
nb = num_blocks
|
||||
epb = schema.entries_per_block
|
||||
hd = schema.entry_head_dim
|
||||
rd = schema.rope_dim
|
||||
fp8_dim = hd - rd
|
||||
|
||||
# ---- Compressed entries ----
|
||||
# FP8 stored as uint8 (we view as float8_e4m3fn at read time).
|
||||
self.entries_fp8 = torch.zeros(
|
||||
(nb, epb, fp8_dim), dtype=torch.uint8, device=device,
|
||||
)
|
||||
# BF16 RoPE'd half — no quantization.
|
||||
self.entries_rope = torch.zeros(
|
||||
(nb, epb, rd), dtype=torch.bfloat16, device=device,
|
||||
)
|
||||
# Per-entry inverse scale (for FP8 dequant in attention kernel).
|
||||
self.inv_scale = torch.ones(
|
||||
(nb, epb), dtype=torch.float32, device=device,
|
||||
)
|
||||
|
||||
# ---- Indexer keys (CSA only) ----
|
||||
if schema.indexer_entries_per_block > 0:
|
||||
i_epb = schema.indexer_entries_per_block
|
||||
i_hd = schema.indexer_head_dim
|
||||
# Indexer QK is FP4 per paper §2.3.4 — but we store the keys
|
||||
# post-quant. uint8 = 2 FP4 packed per byte.
|
||||
self.indexer_keys_fp4 = torch.zeros(
|
||||
(nb, i_epb, i_hd // 2), dtype=torch.uint8, device=device,
|
||||
)
|
||||
# Per-block-vector scale for the FP4 (one E4M3 scalar per
|
||||
# 16-element group, per the NVFP4 quantization scheme).
|
||||
self.indexer_scale = torch.ones(
|
||||
(nb, i_epb, i_hd // 16),
|
||||
dtype=torch.float8_e4m3fn, device=device,
|
||||
)
|
||||
self.indexer_global_scale = torch.ones(
|
||||
(nb,), dtype=torch.float32, device=device,
|
||||
)
|
||||
else:
|
||||
self.indexer_keys_fp4 = None
|
||||
self.indexer_scale = None
|
||||
self.indexer_global_scale = None
|
||||
|
||||
def memory_bytes(self) -> int:
|
||||
"""Total GPU memory used by this pool."""
|
||||
total = 0
|
||||
for name in ("entries_fp8", "entries_rope", "inv_scale",
|
||||
"indexer_keys_fp4", "indexer_scale", "indexer_global_scale"):
|
||||
t = getattr(self, name)
|
||||
if t is not None:
|
||||
total += t.numel() * t.element_size()
|
||||
return total
|
||||
|
||||
129
dsv4/cache/schema.py
vendored
Normal file
129
dsv4/cache/schema.py
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Per-layer KV cache shape.
|
||||
|
||||
Computed once per layer at engine startup from the LayerSpec. The
|
||||
schema is what tells the allocator how big each pool slot is and what
|
||||
sub-regions exist (compressed entries / indexer keys / SWA window /
|
||||
uncompressed tail).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from dsv4.model.config import DSV4Config
|
||||
from dsv4.model.layer_schedule import LayerSpec, AttentionType
|
||||
|
||||
|
||||
# Block size is invariant for DSV4 — derived from compression ratios.
|
||||
# lcm(m, m') = lcm(4, 128) = 128 original tokens per block.
|
||||
# Holds 128/4 = 32 CSA entries OR 128/128 = 1 HCA entry per block.
|
||||
BLOCK_SIZE_ORIGINAL_TOKENS = 128
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LayerCacheSchema:
|
||||
"""Cache layout for one transformer layer.
|
||||
|
||||
Fields with `_per_block` are the dimensions of one block in the
|
||||
classical paged pool. `_per_state_slot` are dimensions of one
|
||||
request's slot in the state cache.
|
||||
|
||||
All sizes are in number of entries — bytes come from the dtypes.
|
||||
"""
|
||||
layer_idx: int
|
||||
attn_type: AttentionType
|
||||
|
||||
# ---- Classical paged cache (compressed entries) ----
|
||||
# Number of compressed entries in one block of BLOCK_SIZE_ORIGINAL_TOKENS
|
||||
# original tokens. For HCA m'=128 this is 1; for CSA m=4 this is 32.
|
||||
# SWA-only layers have no classical pool — entries_per_block = 0.
|
||||
entries_per_block: int
|
||||
# Width of one entry (head_dim).
|
||||
entry_head_dim: int
|
||||
# RoPE-applied dimensions (BF16). Others FP8.
|
||||
rope_dim: int
|
||||
|
||||
# ---- Indexer pool (CSA only) ----
|
||||
# Compressed indexer keys, one per compressed entry.
|
||||
indexer_entries_per_block: int # 32 for CSA, 0 for HCA/SWA
|
||||
indexer_head_dim: int # 128 for CSA, 0 for others
|
||||
|
||||
# ---- State cache (SWA window + uncompressed tail) ----
|
||||
swa_window_size: int # 128 for all layer types
|
||||
# Uncompressed tail buffer — needed only for compressed layers.
|
||||
# CSA: up to m-1 = 3 pending tokens before flushing compression.
|
||||
# HCA: up to m'-1 = 127 pending tokens.
|
||||
# SWA-only: no tail (no compression branch).
|
||||
tail_buffer_size: int
|
||||
|
||||
# Per-token inverse scale storage (for FP8 dequant). One FP32 scalar
|
||||
# per stored entry/window-slot.
|
||||
needs_inv_scale: bool = True
|
||||
|
||||
|
||||
def build_schema(config: DSV4Config, spec: LayerSpec) -> LayerCacheSchema:
|
||||
"""Derive cache schema for a single layer from architectural config."""
|
||||
if spec.attn == AttentionType.CSA:
|
||||
return LayerCacheSchema(
|
||||
layer_idx=spec.layer_idx,
|
||||
attn_type=AttentionType.CSA,
|
||||
entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // config.csa_compression_ratio,
|
||||
entry_head_dim=config.head_dim,
|
||||
rope_dim=config.rope_dim,
|
||||
indexer_entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // config.csa_compression_ratio,
|
||||
indexer_head_dim=config.indexer_head_dim,
|
||||
swa_window_size=config.sliding_window,
|
||||
tail_buffer_size=config.csa_compression_ratio - 1,
|
||||
)
|
||||
elif spec.attn == AttentionType.HCA:
|
||||
return LayerCacheSchema(
|
||||
layer_idx=spec.layer_idx,
|
||||
attn_type=AttentionType.HCA,
|
||||
entries_per_block=BLOCK_SIZE_ORIGINAL_TOKENS // config.hca_compression_ratio,
|
||||
entry_head_dim=config.head_dim,
|
||||
rope_dim=config.rope_dim,
|
||||
indexer_entries_per_block=0,
|
||||
indexer_head_dim=0,
|
||||
swa_window_size=config.sliding_window,
|
||||
tail_buffer_size=config.hca_compression_ratio - 1,
|
||||
)
|
||||
else: # SWA-only
|
||||
return LayerCacheSchema(
|
||||
layer_idx=spec.layer_idx,
|
||||
attn_type=AttentionType.SWA,
|
||||
entries_per_block=0,
|
||||
entry_head_dim=config.head_dim,
|
||||
rope_dim=config.rope_dim,
|
||||
indexer_entries_per_block=0,
|
||||
indexer_head_dim=0,
|
||||
swa_window_size=config.sliding_window,
|
||||
tail_buffer_size=0,
|
||||
)
|
||||
|
||||
|
||||
def compute_block_budget(
|
||||
config: DSV4Config,
|
||||
schedule: list[LayerSpec],
|
||||
max_context_tokens: int,
|
||||
max_concurrent_requests: int,
|
||||
) -> dict[str, int]:
|
||||
"""Compute per-layer-type block counts for the allocator.
|
||||
|
||||
Returns {layer_type: num_blocks} where layer_type is 'csa' or 'hca'.
|
||||
SWA-only layers need no classical blocks.
|
||||
|
||||
Block budget = max_concurrent_requests * (max_context / BLOCK_SIZE).
|
||||
Add 10% headroom for fragmentation.
|
||||
"""
|
||||
blocks_per_request = max_context_tokens // BLOCK_SIZE_ORIGINAL_TOKENS
|
||||
headroom = 1.10
|
||||
result = {}
|
||||
for spec in schedule:
|
||||
if spec.attn == AttentionType.CSA:
|
||||
key = "csa"
|
||||
elif spec.attn == AttentionType.HCA:
|
||||
key = "hca"
|
||||
else:
|
||||
continue
|
||||
total = int(max_concurrent_requests * blocks_per_request * headroom)
|
||||
result[key] = max(result.get(key, 0), total)
|
||||
return result
|
||||
98
dsv4/cache/state_cache.py
vendored
98
dsv4/cache/state_cache.py
vendored
@@ -1,2 +1,96 @@
|
||||
"""State cache for KV."""
|
||||
# TODO: Phase 3
|
||||
"""State cache: SWA window + uncompressed tail buffer.
|
||||
|
||||
One slot per active request. Slot index is fixed for a request's
|
||||
lifetime — the manager hands out slot indices at request admission
|
||||
and reclaims them at completion.
|
||||
|
||||
Per paper §3.5.1: SWA and tail tokens are state-space-like — they
|
||||
depend only on the current position, not on a paged history. No
|
||||
block table; a flat [max_requests, ...] tensor.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
|
||||
from dsv4.cache.schema import LayerCacheSchema, AttentionType
|
||||
|
||||
|
||||
class StateCachePool:
|
||||
"""Per-layer state cache (SWA window + uncompressed tail).
|
||||
|
||||
Storage layout per slot:
|
||||
swa_fp8: [n_win, head_dim - rope_dim] FP8 raw KV in window
|
||||
swa_rope: [n_win, rope_dim] BF16 RoPE'd half
|
||||
swa_inv: [n_win] FP32 per-token inv scale
|
||||
swa_pos: [n_win] int32 — absolute position
|
||||
of each window slot (-1 if invalid)
|
||||
|
||||
tail_ka: [tail_size, head_dim] BF16 raw — pending tokens
|
||||
not yet compressed
|
||||
tail_za: [tail_size, head_dim] BF16 — compression weights
|
||||
(Z stream for CSA, single Z for HCA)
|
||||
tail_kb: [tail_size, head_dim] BF16 — second stream (CSA only)
|
||||
tail_zb: [tail_size, head_dim] BF16 — second Z stream (CSA only)
|
||||
tail_len: scalar int32 — how many tail entries are valid
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: LayerCacheSchema,
|
||||
max_requests: int,
|
||||
device: str = "cuda",
|
||||
):
|
||||
self.schema = schema
|
||||
self.max_requests = max_requests
|
||||
self.device = device
|
||||
|
||||
mr = max_requests
|
||||
nw = schema.swa_window_size
|
||||
hd = schema.entry_head_dim
|
||||
rd = schema.rope_dim
|
||||
fp8 = hd - rd
|
||||
|
||||
# SWA window — circular within each slot. Layer's attention
|
||||
# kernel uses swa_pos to mask invalid entries.
|
||||
self.swa_fp8 = torch.zeros((mr, nw, fp8), dtype=torch.uint8, device=device)
|
||||
self.swa_rope = torch.zeros((mr, nw, rd), dtype=torch.bfloat16, device=device)
|
||||
self.swa_inv = torch.ones((mr, nw), dtype=torch.float32, device=device)
|
||||
self.swa_pos = torch.full((mr, nw), -1, dtype=torch.int32, device=device)
|
||||
# Next write position within each slot's ring buffer.
|
||||
self.swa_head = torch.zeros((mr,), dtype=torch.int32, device=device)
|
||||
|
||||
# Tail buffer — only non-empty for compressed layers.
|
||||
tail = schema.tail_buffer_size
|
||||
if tail > 0:
|
||||
# For CSA we need two streams (Ca/Cb, Za/Zb) since the
|
||||
# compressor uses overlapping pairs. HCA only needs one
|
||||
# stream. Store both; HCA leaves the b-channel zero.
|
||||
self.tail_ka = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
|
||||
self.tail_za = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
|
||||
if schema.attn_type == AttentionType.CSA:
|
||||
self.tail_kb = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
|
||||
self.tail_zb = torch.zeros((mr, tail, hd), dtype=torch.bfloat16, device=device)
|
||||
else:
|
||||
self.tail_kb = None
|
||||
self.tail_zb = None
|
||||
self.tail_len = torch.zeros((mr,), dtype=torch.int32, device=device)
|
||||
else:
|
||||
self.tail_ka = self.tail_kb = None
|
||||
self.tail_za = self.tail_zb = None
|
||||
self.tail_len = None
|
||||
|
||||
def reset_slot(self, slot: int) -> None:
|
||||
"""Clear a request's state after completion."""
|
||||
self.swa_pos[slot].fill_(-1)
|
||||
self.swa_head[slot] = 0
|
||||
if self.tail_len is not None:
|
||||
self.tail_len[slot] = 0
|
||||
|
||||
def memory_bytes(self) -> int:
|
||||
"""Total GPU memory used by this pool."""
|
||||
total = 0
|
||||
for name in ("swa_fp8", "swa_rope", "swa_inv", "swa_pos", "swa_head",
|
||||
"tail_ka", "tail_za", "tail_kb", "tail_zb", "tail_len"):
|
||||
t = getattr(self, name)
|
||||
if t is not None:
|
||||
total += t.numel() * t.element_size()
|
||||
return total
|
||||
|
||||
1
dsv4/kernels/cache/__init__.py
vendored
Normal file
1
dsv4/kernels/cache/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
51
dsv4/kernels/cache/append_swa.py
vendored
Normal file
51
dsv4/kernels/cache/append_swa.py
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Python wrapper for the append_swa CUDA kernel.
|
||||
|
||||
Writes raw BF16 KV into the FP8/BF16 split state cache layout.
|
||||
Quantizes the non-RoPE half BF16 -> FP8 (E4M3 amax-based scaling),
|
||||
writes the RoPE half as-is, computes per-token inverse scale, and
|
||||
updates the ring buffer head + position field.
|
||||
|
||||
One block per token. Threads cooperatively:
|
||||
1. Compute amax over fp8-dim elements (warp reduce).
|
||||
2. Quantize BF16 -> FP8 with per-token scale.
|
||||
3. Write FP8 entries + BF16 RoPE entries + inv_scale + position.
|
||||
4. Atomic increment ring buffer head.
|
||||
"""
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
_kernel_module = None
|
||||
|
||||
|
||||
def _get_kernel_module():
|
||||
global _kernel_module
|
||||
if _kernel_module is not None:
|
||||
return _kernel_module
|
||||
kernel_dir = os.path.join(os.path.dirname(__file__), "..", "cuda")
|
||||
_kernel_module = load(
|
||||
name="append_swa",
|
||||
sources=[os.path.join(kernel_dir, "append_swa.cu")],
|
||||
extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
|
||||
verbose=False,
|
||||
)
|
||||
return _kernel_module
|
||||
|
||||
|
||||
def append_swa_kernel(
|
||||
raw_kv: torch.Tensor, # (T, head_dim) BF16
|
||||
request_slots: torch.Tensor, # (T,) int32
|
||||
positions: torch.Tensor, # (T,) int32
|
||||
swa_fp8: torch.Tensor, # (max_req, n_win, fp8_dim) uint8
|
||||
swa_rope: torch.Tensor, # (max_req, n_win, rope_dim) BF16
|
||||
swa_inv: torch.Tensor, # (max_req, n_win) FP32
|
||||
swa_pos: torch.Tensor, # (max_req, n_win) int32
|
||||
swa_head: torch.Tensor, # (max_req,) int32
|
||||
rope_dim: int,
|
||||
):
|
||||
mod = _get_kernel_module()
|
||||
mod.append_swa(
|
||||
raw_kv, request_slots, positions,
|
||||
swa_fp8, swa_rope, swa_inv, swa_pos, swa_head,
|
||||
rope_dim,
|
||||
)
|
||||
165
dsv4/kernels/cuda/append_swa.cu
Normal file
165
dsv4/kernels/cuda/append_swa.cu
Normal file
@@ -0,0 +1,165 @@
|
||||
// append_swa.cu — write raw BF16 KV into the SWA ring buffer.
|
||||
//
|
||||
// One block per token. Threads cooperatively:
|
||||
// 1. Compute amax over fp8-dim elements (warp reduce).
|
||||
// 2. Quantize BF16 -> FP8 E4M3 with per-token scale.
|
||||
// 3. Write FP8 entries + BF16 RoPE entries + inv_scale + position.
|
||||
// 4. Atomic increment ring buffer head.
|
||||
//
|
||||
// Paper §2.3.4: BF16 for RoPE'd dims, FP8 for the rest.
|
||||
// Per-token inverse scale stored for dequant in the attention kernel.
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
// Warp-level amax reduction
|
||||
__device__ __forceinline__ float warp_reduce_amax(float val) {
|
||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||
float other = __shfl_down_sync(0xffffffff, val, offset);
|
||||
val = fmaxf(val, fabsf(other));
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
// Warp-level sum for counting valid entries
|
||||
__device__ __forceinline__ float warp_reduce_sum(float val) {
|
||||
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||
val += __shfl_down_sync(0xffffffff, val, offset);
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
__global__ void append_swa_kernel(
|
||||
const __nv_bfloat16* __restrict__ raw_kv, // [T, head_dim]
|
||||
const int32_t* __restrict__ request_slots, // [T] -> slot in state pool
|
||||
const int32_t* __restrict__ positions, // [T] -> absolute position
|
||||
// State cache pool — written in place.
|
||||
uint8_t* __restrict__ swa_fp8, // [max_req, n_win, fp8_dim]
|
||||
__nv_bfloat16* __restrict__ swa_rope, // [max_req, n_win, rope_dim]
|
||||
float* __restrict__ swa_inv, // [max_req, n_win]
|
||||
int32_t* __restrict__ swa_pos, // [max_req, n_win]
|
||||
int32_t* __restrict__ swa_head, // [max_req]
|
||||
int T, int n_win, int head_dim, int rope_dim
|
||||
) {
|
||||
int t = blockIdx.x;
|
||||
if (t >= T) return;
|
||||
|
||||
int lane = threadIdx.x;
|
||||
int warp_size = blockDim.x; // expect 128 threads per block
|
||||
|
||||
int slot = request_slots[t];
|
||||
int pos = positions[t];
|
||||
int fp8_dim = head_dim - rope_dim;
|
||||
|
||||
// ---- Step 1: Compute amax over fp8_dim elements ----
|
||||
// Each thread processes strided elements of the fp8 half.
|
||||
float local_amax = 0.0f;
|
||||
for (int i = lane; i < fp8_dim; i += warp_size) {
|
||||
float val = __bfloat162float(raw_kv[t * head_dim + i]);
|
||||
local_amax = fmaxf(local_amax, fabsf(val));
|
||||
}
|
||||
|
||||
// Warp-level amax reduction (works for warp_size <= 32).
|
||||
// For 128 threads, we need to reduce across 4 warps.
|
||||
float block_amax = 0.0f;
|
||||
// Intra-warp reduce
|
||||
float warp_amax = warp_reduce_amax(local_amax);
|
||||
// Lane 0 of each warp writes to shared memory
|
||||
__shared__ float smem_amax[4]; // max 4 warps for 128 threads
|
||||
if (lane % 32 == 0) {
|
||||
smem_amax[lane / 32] = warp_amax;
|
||||
}
|
||||
__syncthreads();
|
||||
if (lane < 32) {
|
||||
float v = (lane < (warp_size + 31) / 32) ? smem_amax[lane] : 0.0f;
|
||||
block_amax = warp_reduce_amax(v);
|
||||
}
|
||||
__syncthreads();
|
||||
// Broadcast block_amax to all threads
|
||||
__shared__ float s_inv_scale;
|
||||
if (lane == 0) {
|
||||
float scale = block_amax / 448.0f; // FP8 E4M3 max = 448
|
||||
if (scale < 1e-12f) scale = 1e-12f; // avoid div-by-zero
|
||||
s_inv_scale = scale;
|
||||
}
|
||||
__syncthreads();
|
||||
float inv_scale_val = s_inv_scale;
|
||||
|
||||
// ---- Step 2: Atomic increment ring buffer head ----
|
||||
// Only one thread per block does the atomic
|
||||
__shared__ int slot_in_window;
|
||||
if (lane == 0) {
|
||||
slot_in_window = atomicAdd(&swa_head[slot], 1) % n_win;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ---- Step 3: Write FP8 entries ----
|
||||
for (int i = lane; i < fp8_dim; i += warp_size) {
|
||||
float val = __bfloat162float(raw_kv[t * head_dim + i]);
|
||||
float quantized = val / inv_scale_val;
|
||||
// Clamp to FP8 E4M3 range [-448, 448]
|
||||
quantized = fmaxf(-448.0f, fminf(448.0f, quantized));
|
||||
// Convert to FP8 E4M3
|
||||
__nv_fp8_e4m3 fp8_val;
|
||||
fp8_val.__x = __nv_fp8_e4m3(quantized).__x;
|
||||
swa_fp8[slot * n_win * fp8_dim + slot_in_window * fp8_dim + i] = fp8_val.__x;
|
||||
}
|
||||
|
||||
// ---- Step 4: Write BF16 RoPE entries ----
|
||||
for (int i = lane; i < rope_dim; i += warp_size) {
|
||||
__nv_bfloat16 val = raw_kv[t * head_dim + fp8_dim + i];
|
||||
swa_rope[slot * n_win * rope_dim + slot_in_window * rope_dim + i] = val;
|
||||
}
|
||||
|
||||
// ---- Step 5: Write metadata (single thread) ----
|
||||
if (lane == 0) {
|
||||
swa_inv[slot * n_win + slot_in_window] = inv_scale_val;
|
||||
swa_pos[slot * n_win + slot_in_window] = pos;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
append_swa_cuda(
|
||||
torch::Tensor raw_kv, // [T, head_dim] BF16
|
||||
torch::Tensor request_slots, // [T] int32
|
||||
torch::Tensor positions, // [T] int32
|
||||
torch::Tensor swa_fp8, // [max_req, n_win, fp8_dim] uint8
|
||||
torch::Tensor swa_rope, // [max_req, n_win, rope_dim] BF16
|
||||
torch::Tensor swa_inv, // [max_req, n_win] FP32
|
||||
torch::Tensor swa_pos, // [max_req, n_win] int32
|
||||
torch::Tensor swa_head, // [max_req] int32
|
||||
int64_t rope_dim
|
||||
) {
|
||||
int T = raw_kv.size(0);
|
||||
int head_dim = raw_kv.size(1);
|
||||
int n_win = swa_fp8.size(1);
|
||||
|
||||
int threads = 128;
|
||||
int blocks = T;
|
||||
|
||||
append_swa_kernel<<<blocks, threads>>>(
|
||||
reinterpret_cast<const __nv_bfloat16*>(raw_kv.data_ptr<at::BFloat16>()),
|
||||
request_slots.data_ptr<int32_t>(),
|
||||
positions.data_ptr<int32_t>(),
|
||||
swa_fp8.data_ptr<uint8_t>(),
|
||||
reinterpret_cast<__nv_bfloat16*>(swa_rope.data_ptr<at::BFloat16>()),
|
||||
swa_inv.data_ptr<float>(),
|
||||
swa_pos.data_ptr<int32_t>(),
|
||||
swa_head.data_ptr<int32_t>(),
|
||||
T, n_win, head_dim, static_cast<int>(rope_dim)
|
||||
);
|
||||
|
||||
C10_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(swa_fp8, swa_rope, swa_inv, swa_pos, swa_head);
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("append_swa", &append_swa_cuda, "Append SWA kernel");
|
||||
}
|
||||
252
tests/unit/test_cache.py
Normal file
252
tests/unit/test_cache.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Tests for KV cache: schema, allocator, pools, manager lifecycle."""
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from dsv4.model.config import DSV4Config
|
||||
from dsv4.model.layer_schedule import build_schedule, AttentionType, RouterMode, LayerSpec
|
||||
from dsv4.cache.schema import build_schema, compute_block_budget, BLOCK_SIZE_ORIGINAL_TOKENS
|
||||
from dsv4.cache.allocator import BlockAllocator
|
||||
from dsv4.cache.paged_cache import PagedKVPool
|
||||
from dsv4.cache.state_cache import StateCachePool
|
||||
from dsv4.cache.manager import KVCacheManager
|
||||
|
||||
|
||||
# ---- Schema tests ----
|
||||
|
||||
def test_csa_schema():
|
||||
config = DSV4Config.pro()
|
||||
spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA,
|
||||
ffn=__import__('dsv4.model.layer_schedule', fromlist=['FFNType']).FFNType.MOE,
|
||||
router_mode=RouterMode.HASH)
|
||||
schema = build_schema(config, spec)
|
||||
assert schema.entries_per_block == 32 # 128 / 4
|
||||
assert schema.indexer_entries_per_block == 32
|
||||
assert schema.tail_buffer_size == 3 # m - 1
|
||||
assert schema.swa_window_size == 128
|
||||
|
||||
|
||||
def test_hca_schema():
|
||||
config = DSV4Config.pro()
|
||||
spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA,
|
||||
ffn=__import__('dsv4.model.layer_schedule', fromlist=['FFNType']).FFNType.MOE,
|
||||
router_mode=RouterMode.DENSE)
|
||||
schema = build_schema(config, spec)
|
||||
assert schema.entries_per_block == 1 # 128 / 128
|
||||
assert schema.indexer_entries_per_block == 0
|
||||
assert schema.tail_buffer_size == 127 # m' - 1
|
||||
|
||||
|
||||
def test_swa_schema():
|
||||
config = DSV4Config.flash()
|
||||
spec = LayerSpec(layer_idx=0, attn=AttentionType.SWA,
|
||||
ffn=__import__('dsv4.model.layer_schedule', fromlist=['FFNType']).FFNType.MOE,
|
||||
router_mode=RouterMode.HASH)
|
||||
schema = build_schema(config, spec)
|
||||
assert schema.entries_per_block == 0
|
||||
assert schema.indexer_entries_per_block == 0
|
||||
assert schema.tail_buffer_size == 0
|
||||
assert schema.swa_window_size == 128
|
||||
|
||||
|
||||
def test_schema_from_schedule():
|
||||
"""Every layer in a full schedule produces a valid schema."""
|
||||
config = DSV4Config.flash()
|
||||
schedule = build_schedule(config)
|
||||
for spec in schedule:
|
||||
schema = build_schema(config, spec)
|
||||
assert schema.swa_window_size > 0
|
||||
assert schema.entry_head_dim == config.head_dim
|
||||
if spec.attn == AttentionType.SWA:
|
||||
assert schema.entries_per_block == 0
|
||||
else:
|
||||
assert schema.entries_per_block > 0
|
||||
|
||||
|
||||
def test_block_budget():
|
||||
config = DSV4Config.pro()
|
||||
schedule = build_schedule(config)
|
||||
budget = compute_block_budget(config, schedule, 1_000_000, 16)
|
||||
assert "csa" in budget
|
||||
assert "hca" in budget
|
||||
assert budget["csa"] > budget["hca"] # CSA uses more blocks per request
|
||||
|
||||
|
||||
# ---- Allocator tests ----
|
||||
|
||||
def test_acquire_release_roundtrip():
|
||||
alloc = BlockAllocator(num_total_blocks=1024)
|
||||
a = alloc.acquire(10)
|
||||
assert alloc.num_free == 1014
|
||||
b = alloc.acquire(5)
|
||||
assert alloc.num_free == 1009
|
||||
alloc.release(a)
|
||||
assert alloc.num_free == 1019
|
||||
alloc.release(b)
|
||||
assert alloc.num_free == 1024
|
||||
# Re-acquire works after release.
|
||||
c = alloc.acquire(20)
|
||||
assert alloc.num_free == 1004
|
||||
|
||||
|
||||
def test_oom_raises():
|
||||
alloc = BlockAllocator(num_total_blocks=4)
|
||||
alloc.acquire(4)
|
||||
with pytest.raises(RuntimeError, match="OOM"):
|
||||
alloc.acquire(1)
|
||||
|
||||
|
||||
def test_acquire_returns_unique_ids():
|
||||
alloc = BlockAllocator(num_total_blocks=100)
|
||||
a = alloc.acquire(50)
|
||||
b = alloc.acquire(50)
|
||||
assert len(torch.intersect1d(a, b)) == 0
|
||||
|
||||
|
||||
# ---- Pool shape tests ----
|
||||
|
||||
def test_paged_pool_shapes_csa():
|
||||
from dsv4.model.layer_schedule import FFNType
|
||||
config = DSV4Config.pro()
|
||||
spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA,
|
||||
ffn=FFNType.MOE, router_mode=RouterMode.HASH)
|
||||
schema = build_schema(config, spec)
|
||||
pool = PagedKVPool(schema, num_blocks=16)
|
||||
assert pool.entries_fp8.shape == (16, 32, config.head_dim - config.rope_dim)
|
||||
assert pool.entries_rope.shape == (16, 32, config.rope_dim)
|
||||
assert pool.inv_scale.shape == (16, 32)
|
||||
assert pool.indexer_keys_fp4 is not None
|
||||
assert pool.indexer_keys_fp4.shape[1] == 32
|
||||
|
||||
|
||||
def test_paged_pool_shapes_hca():
|
||||
from dsv4.model.layer_schedule import FFNType
|
||||
config = DSV4Config.pro()
|
||||
spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA,
|
||||
ffn=FFNType.MOE, router_mode=RouterMode.DENSE)
|
||||
schema = build_schema(config, spec)
|
||||
pool = PagedKVPool(schema, num_blocks=256)
|
||||
assert pool.entries_fp8.shape == (256, 1, config.head_dim - config.rope_dim)
|
||||
assert pool.entries_rope.shape == (256, 1, config.rope_dim)
|
||||
assert pool.indexer_keys_fp4 is None
|
||||
|
||||
|
||||
def test_state_pool_shapes_csa():
|
||||
from dsv4.model.layer_schedule import FFNType
|
||||
config = DSV4Config.pro()
|
||||
spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA,
|
||||
ffn=FFNType.MOE, router_mode=RouterMode.HASH)
|
||||
schema = build_schema(config, spec)
|
||||
pool = StateCachePool(schema, max_requests=8)
|
||||
assert pool.swa_fp8.shape == (8, 128, config.head_dim - config.rope_dim)
|
||||
assert pool.swa_rope.shape == (8, 128, config.rope_dim)
|
||||
assert pool.tail_ka is not None
|
||||
assert pool.tail_kb is not None # CSA has both streams
|
||||
assert pool.tail_len is not None
|
||||
|
||||
|
||||
def test_state_pool_shapes_hca():
|
||||
from dsv4.model.layer_schedule import FFNType
|
||||
config = DSV4Config.pro()
|
||||
spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA,
|
||||
ffn=FFNType.MOE, router_mode=RouterMode.DENSE)
|
||||
schema = build_schema(config, spec)
|
||||
pool = StateCachePool(schema, max_requests=8)
|
||||
assert pool.tail_ka is not None
|
||||
assert pool.tail_kb is None # HCA only one stream
|
||||
assert pool.tail_za is not None
|
||||
assert pool.tail_len is not None
|
||||
|
||||
|
||||
def test_state_pool_shapes_swa():
|
||||
from dsv4.model.layer_schedule import FFNType
|
||||
config = DSV4Config.flash()
|
||||
spec = LayerSpec(layer_idx=0, attn=AttentionType.SWA,
|
||||
ffn=FFNType.MOE, router_mode=RouterMode.HASH)
|
||||
schema = build_schema(config, spec)
|
||||
pool = StateCachePool(schema, max_requests=8)
|
||||
assert pool.swa_fp8.shape == (8, 128, config.head_dim - config.rope_dim)
|
||||
assert pool.tail_ka is None # No tail for SWA-only
|
||||
assert pool.tail_len is None
|
||||
|
||||
|
||||
# ---- Manager lifecycle tests ----
|
||||
|
||||
def test_admit_release_recycles_slot():
|
||||
config = DSV4Config.flash()
|
||||
schedule = build_schedule(config)
|
||||
mgr = KVCacheManager(config, schedule, max_concurrent_requests=4,
|
||||
num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64)
|
||||
s1 = mgr.admit_request()
|
||||
mgr.release_request(s1)
|
||||
s2 = mgr.admit_request()
|
||||
assert s1 == s2 # slot was recycled
|
||||
|
||||
|
||||
def test_admit_exhaustion():
|
||||
config = DSV4Config.flash()
|
||||
schedule = build_schedule(config)
|
||||
mgr = KVCacheManager(config, schedule, max_concurrent_requests=2,
|
||||
num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64)
|
||||
mgr.admit_request()
|
||||
mgr.admit_request()
|
||||
with pytest.raises(RuntimeError, match="concurrent"):
|
||||
mgr.admit_request()
|
||||
|
||||
|
||||
def test_handle_construction_no_alloc():
|
||||
"""acquire() should not allocate GPU memory — critical for cudagraph."""
|
||||
config = DSV4Config.flash()
|
||||
schedule = build_schedule(config)
|
||||
mgr = KVCacheManager(config, schedule, max_concurrent_requests=4,
|
||||
num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64)
|
||||
slot = mgr.admit_request()
|
||||
torch.cuda.synchronize()
|
||||
before = torch.cuda.memory_allocated()
|
||||
handle = mgr.acquire(
|
||||
layer_idx=0,
|
||||
request_slots=torch.tensor([slot], dtype=torch.int32, device="cuda"),
|
||||
positions=torch.tensor([0], dtype=torch.int32, device="cuda"),
|
||||
request_ids=torch.tensor([0], dtype=torch.int32, device="cuda"),
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
after = torch.cuda.memory_allocated()
|
||||
# Allow small variance from tensor view creation, but no large alloc
|
||||
assert after - before < 1024, f"acquire() allocated {after - before} bytes — breaks cudagraph"
|
||||
assert handle.paged is None # layer 0 is SWA
|
||||
mgr.release_request(slot)
|
||||
|
||||
|
||||
def test_manager_memory_tracking():
|
||||
config = DSV4Config.flash()
|
||||
schedule = build_schedule(config)
|
||||
mgr = KVCacheManager(config, schedule, max_concurrent_requests=4,
|
||||
num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64)
|
||||
total = mgr.memory_bytes()
|
||||
assert total > 0
|
||||
# Rough sanity: should be in the MB range for this config
|
||||
assert total > 1_000_000 # at least 1 MB
|
||||
assert total < 10_000_000_000 # less than 10 GB
|
||||
|
||||
|
||||
def test_full_flash_stack_construction():
|
||||
"""Construct manager with all 43 Flash layers — pools for every layer."""
|
||||
config = DSV4Config.flash()
|
||||
schedule = build_schedule(config)
|
||||
mgr = KVCacheManager(config, schedule, max_concurrent_requests=4,
|
||||
num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64)
|
||||
# 2 SWA layers (no paged pool) + 41 compressed layers
|
||||
assert len(mgr.paged_pools) == 43
|
||||
assert mgr.paged_pools[0] is None # SWA
|
||||
assert mgr.paged_pools[1] is None # SWA
|
||||
assert mgr.paged_pools[2] is not None # CSA
|
||||
assert mgr.paged_pools[3] is not None # HCA
|
||||
|
||||
# All state pools present
|
||||
assert len(mgr.state_pools) == 43
|
||||
for i in range(43):
|
||||
assert mgr.state_pools[i] is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user