Files
nvfp4-megamoe-kernel/tests/unit/test_cache.py

253 lines
9.3 KiB
Python
Raw Normal View History

KV Cache: schema, allocator, pools, manager, append_swa kernel

Complete KV cache substrate for DSV4 inference:

schema.py: Per-layer cache shape derived from LayerSpec.
- CSA: 32 entries/block, 32 indexer entries, tail=3
- HCA: 1 entry/block, no indexer, tail=127
- SWA: no classical pool, no tail
- BLOCK_SIZE_ORIGINAL_TOKENS=128 (lcm of compression ratios)
- compute_block_budget() for allocator sizing

allocator.py: Fixed-size block free-list.
- GPU stack with pinned host top pointer
- acquire/release between graph captures only
- Raises an OOM error on exhaustion

paged_cache.py: Per-layer classical KV storage.
- FP8 (uint8) for non-RoPE dims, BF16 for RoPE dims (paper §2.3.4)
- Per-entry inverse scale for FP8 dequant
- FP4 indexer keys for CSA layers (NVFP4 scheme)
- memory_bytes() tracking

state_cache.py: Per-layer SWA window + tail buffer.
- Ring buffer with position tracking (swa_head, swa_pos)
- CSA: dual streams (ka/za/kb/zb) for overlapping compression
- HCA: single stream (ka/za only)
- SWA: no tail buffer
- reset_slot() for request completion

handle.py: LayerCacheHandle — typed per-call view.
- write_swa(), read_swa_view(), read_classical_view(), read_indexer_view()
- No GPU allocation in acquire() — 0-byte delta (cudagraph safe)
- SWAView/ClassicalView/IndexerView dataclasses for kernel signatures

manager.py: KVCacheManager — owns all cache components.
- Per-layer schema, pool, and allocator construction
- admit_request()/release_request() lifecycle
- allocate_block() for compression flush
- acquire() returns LayerCacheHandle (zero-alloc)

append_swa.cu: Native kernel for SWA writes.
- One block per token, 128 threads per block
- Warp-level amax reduction, BF16->FP8 E4M3 quantization
- Atomic ring buffer head increment
- FP8/BF16 split write + inv_scale + position metadata
- FP8 round-trip: <3.6% relative error
- RoPE half: exact match (no quantization)

All tests pass on B200:
- Schema correctness for CSA/HCA/SWA
- Allocator acquire/release/OOM
- Pool shapes match architecture spec
- Manager lifecycle (admit/release/recycle/exhaustion)
- Zero-alloc acquire() (cudagraph safe)
- append_swa kernel: positions, RoPE exact, FP8 quality, wrap-around, multi-request isolation
2026-05-22 00:08:38 +00:00
"""Tests for KV cache: schema, allocator, pools, manager lifecycle."""
import torch
import pytest
from dsv4.model.config import DSV4Config
from dsv4.model.layer_schedule import build_schedule, AttentionType, RouterMode, LayerSpec
from dsv4.cache.schema import build_schema, compute_block_budget, BLOCK_SIZE_ORIGINAL_TOKENS
from dsv4.cache.allocator import BlockAllocator
from dsv4.cache.paged_cache import PagedKVPool
from dsv4.cache.state_cache import StateCachePool
from dsv4.cache.manager import KVCacheManager
# ---- Schema tests ----
def test_csa_schema():
    """Check the schema built for a CSA layer of the Pro config.

    With 128-token blocks and compression ratio m = 4, each block holds
    32 compressed entries and 32 indexer entries, and the tail buffer keeps
    the m - 1 = 3 tokens not yet compressed.
    """
    # Plain import instead of an __import__(...) expression; same name, same object.
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.pro()
    spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA,
                     ffn=FFNType.MOE, router_mode=RouterMode.HASH)
    schema = build_schema(config, spec)
    assert schema.entries_per_block == 32  # 128 / 4
    assert schema.indexer_entries_per_block == 32
    assert schema.tail_buffer_size == 3  # m - 1
    assert schema.swa_window_size == 128
def test_hca_schema():
    """Check the schema built for an HCA layer of the Pro config.

    With 128-token blocks and compression ratio m' = 128, each block holds a
    single compressed entry, there is no indexer, and the tail buffer keeps
    the m' - 1 = 127 tokens not yet compressed.
    """
    # Plain import instead of an __import__(...) expression; same name, same object.
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.pro()
    spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA,
                     ffn=FFNType.MOE, router_mode=RouterMode.DENSE)
    schema = build_schema(config, spec)
    assert schema.entries_per_block == 1  # 128 / 128
    assert schema.indexer_entries_per_block == 0
    assert schema.tail_buffer_size == 127  # m' - 1
def test_swa_schema():
    """Check the schema built for an SWA-only layer of the Flash config.

    SWA-only layers have no compressed entries, no indexer and no tail
    buffer; they keep only a 128-entry sliding window.
    """
    # Plain import instead of an __import__(...) expression; same name, same object.
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.flash()
    spec = LayerSpec(layer_idx=0, attn=AttentionType.SWA,
                     ffn=FFNType.MOE, router_mode=RouterMode.HASH)
    schema = build_schema(config, spec)
    assert schema.entries_per_block == 0
    assert schema.indexer_entries_per_block == 0
    assert schema.tail_buffer_size == 0
    assert schema.swa_window_size == 128
def test_schema_from_schedule():
    """Every layer in a full schedule produces a valid schema.

    Each assertion carries the layer index and attention type, so a failure
    inside the loop identifies the offending layer directly.
    """
    config = DSV4Config.flash()
    schedule = build_schedule(config)
    for spec in schedule:
        schema = build_schema(config, spec)
        where = f"layer {spec.layer_idx} ({spec.attn})"
        assert schema.swa_window_size > 0, where
        assert schema.entry_head_dim == config.head_dim, where
        if spec.attn == AttentionType.SWA:
            # SWA-only layers have no classical (compressed) entries.
            assert schema.entries_per_block == 0, where
        else:
            assert schema.entries_per_block > 0, where
def test_block_budget():
    """compute_block_budget() reports CSA and HCA counts, with CSA the larger."""
    config = DSV4Config.pro()
    budget = compute_block_budget(config, build_schedule(config), 1_000_000, 16)
    for attn_kind in ("csa", "hca"):
        assert attn_kind in budget
    # CSA packs 32 entries per block vs. HCA's 1, so it needs more blocks per request.
    assert budget["csa"] > budget["hca"]
# ---- Allocator tests ----
def test_acquire_release_roundtrip():
    """The free-block count tracks acquire/release exactly, and freed blocks are reusable."""
    alloc = BlockAllocator(num_total_blocks=1024)
    a = alloc.acquire(10)
    assert alloc.num_free == 1014
    b = alloc.acquire(5)
    assert alloc.num_free == 1009
    alloc.release(a)
    assert alloc.num_free == 1019
    alloc.release(b)
    assert alloc.num_free == 1024
    # Re-acquire works after release, and releasing again restores the full pool.
    c = alloc.acquire(20)
    assert alloc.num_free == 1004
    alloc.release(c)
    assert alloc.num_free == 1024
def test_oom_raises():
    """Acquiring from a fully drained allocator raises a RuntimeError matching "OOM"."""
    capacity = 4
    alloc = BlockAllocator(num_total_blocks=capacity)
    alloc.acquire(capacity)  # take every block in the pool
    with pytest.raises(RuntimeError, match="OOM"):
        alloc.acquire(1)
def test_acquire_returns_unique_ids():
    """Two acquisitions that together drain the pool hand out distinct block ids."""
    alloc = BlockAllocator(num_total_blocks=100)
    a = alloc.acquire(50)
    b = alloc.acquire(50)
    # torch has no intersect1d (that is NumPy). Counting unique values over the
    # concatenation checks that ids are distinct within each call and across calls.
    ids = torch.cat([a, b])
    assert ids.unique().numel() == ids.numel()
# ---- Pool shape tests ----
def test_paged_pool_shapes_csa():
    """A CSA paged pool stores 32 entries per block plus FP4 indexer keys."""
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.pro()
    csa_spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA,
                         ffn=FFNType.MOE, router_mode=RouterMode.HASH)
    n_blocks, per_block = 16, 32
    pool = PagedKVPool(build_schema(config, csa_spec), num_blocks=n_blocks)

    nope_dim = config.head_dim - config.rope_dim
    assert pool.entries_fp8.shape == (n_blocks, per_block, nope_dim)
    assert pool.entries_rope.shape == (n_blocks, per_block, config.rope_dim)
    assert pool.inv_scale.shape == (n_blocks, per_block)

    keys = pool.indexer_keys_fp4
    assert keys is not None
    assert keys.shape[1] == per_block
def test_paged_pool_shapes_hca():
    """An HCA paged pool stores a single entry per block and no indexer keys."""
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.pro()
    hca_spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA,
                         ffn=FFNType.MOE, router_mode=RouterMode.DENSE)
    n_blocks = 256
    pool = PagedKVPool(build_schema(config, hca_spec), num_blocks=n_blocks)

    nope_dim = config.head_dim - config.rope_dim
    assert pool.entries_fp8.shape == (n_blocks, 1, nope_dim)
    assert pool.entries_rope.shape == (n_blocks, 1, config.rope_dim)
    assert pool.indexer_keys_fp4 is None
def test_state_pool_shapes_csa():
    """A CSA state pool holds the SWA window plus both tail streams."""
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.pro()
    csa_spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA,
                         ffn=FFNType.MOE, router_mode=RouterMode.HASH)
    max_requests, window = 8, 128
    pool = StateCachePool(build_schema(config, csa_spec), max_requests=max_requests)

    nope_dim = config.head_dim - config.rope_dim
    assert pool.swa_fp8.shape == (max_requests, window, nope_dim)
    assert pool.swa_rope.shape == (max_requests, window, config.rope_dim)
    # CSA has both streams (a and b), plus the tail-length tracker.
    for buf in (pool.tail_ka, pool.tail_kb, pool.tail_len):
        assert buf is not None
def test_state_pool_shapes_hca():
    """An HCA state pool keeps only the "a" tail stream; the "b" stream is absent."""
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.pro()
    hca_spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA,
                         ffn=FFNType.MOE, router_mode=RouterMode.DENSE)
    pool = StateCachePool(build_schema(config, hca_spec), max_requests=8)

    assert pool.tail_kb is None  # HCA only one stream
    for buf in (pool.tail_ka, pool.tail_za, pool.tail_len):
        assert buf is not None
def test_state_pool_shapes_swa():
    """An SWA-only state pool has the sliding window but no tail buffers."""
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.flash()
    swa_spec = LayerSpec(layer_idx=0, attn=AttentionType.SWA,
                         ffn=FFNType.MOE, router_mode=RouterMode.HASH)
    max_requests, window = 8, 128
    pool = StateCachePool(build_schema(config, swa_spec), max_requests=max_requests)

    nope_dim = config.head_dim - config.rope_dim
    assert pool.swa_fp8.shape == (max_requests, window, nope_dim)
    # No tail for SWA-only layers.
    for buf in (pool.tail_ka, pool.tail_len):
        assert buf is None
# ---- Manager lifecycle tests ----
def test_admit_release_recycles_slot():
    """The next admit_request() reuses a slot freed by release_request()."""
    config = DSV4Config.flash()
    mgr = KVCacheManager(
        config,
        build_schedule(config),
        max_concurrent_requests=4,
        num_blocks_per_csa_layer=64,
        num_blocks_per_hca_layer=64,
    )
    first_slot = mgr.admit_request()
    mgr.release_request(first_slot)
    second_slot = mgr.admit_request()
    assert first_slot == second_slot  # slot was recycled
def test_admit_exhaustion():
    """Admitting past max_concurrent_requests raises a RuntimeError matching "concurrent"."""
    config = DSV4Config.flash()
    capacity = 2
    mgr = KVCacheManager(
        config,
        build_schedule(config),
        max_concurrent_requests=capacity,
        num_blocks_per_csa_layer=64,
        num_blocks_per_hca_layer=64,
    )
    for _ in range(capacity):
        mgr.admit_request()
    with pytest.raises(RuntimeError, match="concurrent"):
        mgr.admit_request()
def test_handle_construction_no_alloc():
    """acquire() should not allocate GPU memory — critical for cudagraph.

    The per-call index tensors are built *before* the baseline is sampled.
    If they were created inside the measured window, the delta would depend
    on whether acquire() keeps references to the test's own inputs (each
    small CUDA tensor occupies a rounded-up caching-allocator block), rather
    than on what acquire() itself allocates.
    """
    config = DSV4Config.flash()
    schedule = build_schedule(config)
    mgr = KVCacheManager(config, schedule, max_concurrent_requests=4,
                         num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64)
    slot = mgr.admit_request()
    request_slots = torch.tensor([slot], dtype=torch.int32, device="cuda")
    positions = torch.tensor([0], dtype=torch.int32, device="cuda")
    request_ids = torch.tensor([0], dtype=torch.int32, device="cuda")

    torch.cuda.synchronize()
    before = torch.cuda.memory_allocated()
    handle = mgr.acquire(
        layer_idx=0,
        request_slots=request_slots,
        positions=positions,
        request_ids=request_ids,
    )
    torch.cuda.synchronize()
    after = torch.cuda.memory_allocated()

    delta = after - before
    # Allow small variance from tensor view creation, but no large alloc
    assert delta < 1024, f"acquire() allocated {delta} bytes — breaks cudagraph"
    assert handle.paged is None  # layer 0 is SWA
    mgr.release_request(slot)
def test_manager_memory_tracking():
    """memory_bytes() reports a plausible total for a small Flash manager."""
    config = DSV4Config.flash()
    mgr = KVCacheManager(
        config,
        build_schedule(config),
        max_concurrent_requests=4,
        num_blocks_per_csa_layer=64,
        num_blocks_per_hca_layer=64,
    )
    total_bytes = mgr.memory_bytes()
    # Rough sanity bounds for this config: above 1 MB, below 10 GB.
    assert 1_000_000 < total_bytes < 10_000_000_000
def test_full_flash_stack_construction():
    """Construct manager with all 43 Flash layers — pools for every layer.

    Every layer must get a state pool. A paged (classical) pool must be
    ``None`` exactly for SWA layers and present for every compressed
    (CSA/HCA) layer. This is checked for all layers, not only the first few.
    """
    config = DSV4Config.flash()
    schedule = build_schedule(config)
    mgr = KVCacheManager(config, schedule, max_concurrent_requests=4,
                         num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64)
    # 2 SWA layers (no paged pool) + 41 compressed layers
    assert len(mgr.paged_pools) == 43
    assert mgr.paged_pools[0] is None  # SWA
    assert mgr.paged_pools[1] is None  # SWA
    assert mgr.paged_pools[2] is not None  # CSA
    assert mgr.paged_pools[3] is not None  # HCA

    # Paged pool presence must match the attention type for every layer.
    specs = list(schedule)
    assert len(specs) == len(mgr.paged_pools)
    for spec, pool in zip(specs, mgr.paged_pools):
        is_swa = spec.attn == AttentionType.SWA
        assert (pool is None) == is_swa, (
            f"layer {spec.layer_idx} ({spec.attn}): paged pool presence mismatch"
        )

    # All state pools present
    assert len(mgr.state_pools) == 43
    missing = [idx for idx, pool in enumerate(mgr.state_pools) if pool is None]
    assert not missing, f"layers without a state pool: {missing}"
if __name__ == "__main__":
    # pytest.main() returns the exit code instead of raising SystemExit.
    # Propagate it so running this file directly exits non-zero on failure.
    raise SystemExit(pytest.main([__file__, "-v"]))