"""Tests for KV cache: schema, allocator, pools, manager lifecycle.""" import torch import pytest from dsv4.model.config import DSV4Config from dsv4.model.layer_schedule import build_schedule, AttentionType, RouterMode, LayerSpec from dsv4.cache.schema import build_schema, compute_block_budget, BLOCK_SIZE_ORIGINAL_TOKENS from dsv4.cache.allocator import BlockAllocator from dsv4.cache.paged_cache import PagedKVPool from dsv4.cache.state_cache import StateCachePool from dsv4.cache.manager import KVCacheManager # ---- Schema tests ---- def test_csa_schema(): config = DSV4Config.pro() spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA, ffn=__import__('dsv4.model.layer_schedule', fromlist=['FFNType']).FFNType.MOE, router_mode=RouterMode.HASH) schema = build_schema(config, spec) assert schema.entries_per_block == 32 # 128 / 4 assert schema.indexer_entries_per_block == 32 assert schema.tail_buffer_size == 3 # m - 1 assert schema.swa_window_size == 128 def test_hca_schema(): config = DSV4Config.pro() spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA, ffn=__import__('dsv4.model.layer_schedule', fromlist=['FFNType']).FFNType.MOE, router_mode=RouterMode.DENSE) schema = build_schema(config, spec) assert schema.entries_per_block == 1 # 128 / 128 assert schema.indexer_entries_per_block == 0 assert schema.tail_buffer_size == 127 # m' - 1 def test_swa_schema(): config = DSV4Config.flash() spec = LayerSpec(layer_idx=0, attn=AttentionType.SWA, ffn=__import__('dsv4.model.layer_schedule', fromlist=['FFNType']).FFNType.MOE, router_mode=RouterMode.HASH) schema = build_schema(config, spec) assert schema.entries_per_block == 0 assert schema.indexer_entries_per_block == 0 assert schema.tail_buffer_size == 0 assert schema.swa_window_size == 128 def test_schema_from_schedule(): """Every layer in a full schedule produces a valid schema.""" config = DSV4Config.flash() schedule = build_schedule(config) for spec in schedule: schema = build_schema(config, spec) assert schema.swa_window_size > 0 assert schema.entry_head_dim == config.head_dim if spec.attn == AttentionType.SWA: assert schema.entries_per_block == 0 else: assert schema.entries_per_block > 0 def test_block_budget(): config = DSV4Config.pro() schedule = build_schedule(config) budget = compute_block_budget(config, schedule, 1_000_000, 16) assert "csa" in budget assert "hca" in budget assert budget["csa"] > budget["hca"] # CSA uses more blocks per request # ---- Allocator tests ---- def test_acquire_release_roundtrip(): alloc = BlockAllocator(num_total_blocks=1024) a = alloc.acquire(10) assert alloc.num_free == 1014 b = alloc.acquire(5) assert alloc.num_free == 1009 alloc.release(a) assert alloc.num_free == 1019 alloc.release(b) assert alloc.num_free == 1024 # Re-acquire works after release. c = alloc.acquire(20) assert alloc.num_free == 1004 def test_oom_raises(): alloc = BlockAllocator(num_total_blocks=4) alloc.acquire(4) with pytest.raises(RuntimeError, match="OOM"): alloc.acquire(1) def test_acquire_returns_unique_ids(): alloc = BlockAllocator(num_total_blocks=100) a = alloc.acquire(50) b = alloc.acquire(50) assert len(torch.intersect1d(a, b)) == 0 # ---- Pool shape tests ---- def test_paged_pool_shapes_csa(): from dsv4.model.layer_schedule import FFNType config = DSV4Config.pro() spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA, ffn=FFNType.MOE, router_mode=RouterMode.HASH) schema = build_schema(config, spec) pool = PagedKVPool(schema, num_blocks=16) assert pool.entries_fp8.shape == (16, 32, config.head_dim - config.rope_dim) assert pool.entries_rope.shape == (16, 32, config.rope_dim) assert pool.inv_scale.shape == (16, 32) assert pool.indexer_keys_fp4 is not None assert pool.indexer_keys_fp4.shape[1] == 32 def test_paged_pool_shapes_hca(): from dsv4.model.layer_schedule import FFNType config = DSV4Config.pro() spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA, ffn=FFNType.MOE, router_mode=RouterMode.DENSE) schema = build_schema(config, spec) pool = PagedKVPool(schema, num_blocks=256) assert pool.entries_fp8.shape == (256, 1, config.head_dim - config.rope_dim) assert pool.entries_rope.shape == (256, 1, config.rope_dim) assert pool.indexer_keys_fp4 is None def test_state_pool_shapes_csa(): from dsv4.model.layer_schedule import FFNType config = DSV4Config.pro() spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA, ffn=FFNType.MOE, router_mode=RouterMode.HASH) schema = build_schema(config, spec) pool = StateCachePool(schema, max_requests=8) assert pool.swa_fp8.shape == (8, 128, config.head_dim - config.rope_dim) assert pool.swa_rope.shape == (8, 128, config.rope_dim) assert pool.tail_ka is not None assert pool.tail_kb is not None # CSA has both streams assert pool.tail_len is not None def test_state_pool_shapes_hca(): from dsv4.model.layer_schedule import FFNType config = DSV4Config.pro() spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA, ffn=FFNType.MOE, router_mode=RouterMode.DENSE) schema = build_schema(config, spec) pool = StateCachePool(schema, max_requests=8) assert pool.tail_ka is not None assert pool.tail_kb is None # HCA only one stream assert pool.tail_za is not None assert pool.tail_len is not None def test_state_pool_shapes_swa(): from dsv4.model.layer_schedule import FFNType config = DSV4Config.flash() spec = LayerSpec(layer_idx=0, attn=AttentionType.SWA, ffn=FFNType.MOE, router_mode=RouterMode.HASH) schema = build_schema(config, spec) pool = StateCachePool(schema, max_requests=8) assert pool.swa_fp8.shape == (8, 128, config.head_dim - config.rope_dim) assert pool.tail_ka is None # No tail for SWA-only assert pool.tail_len is None # ---- Manager lifecycle tests ---- def test_admit_release_recycles_slot(): config = DSV4Config.flash() schedule = build_schedule(config) mgr = KVCacheManager(config, schedule, max_concurrent_requests=4, num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64) s1 = mgr.admit_request() mgr.release_request(s1) s2 = mgr.admit_request() assert s1 == s2 # slot was recycled def test_admit_exhaustion(): config = DSV4Config.flash() schedule = build_schedule(config) mgr = KVCacheManager(config, schedule, max_concurrent_requests=2, num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64) mgr.admit_request() mgr.admit_request() with pytest.raises(RuntimeError, match="concurrent"): mgr.admit_request() def test_handle_construction_no_alloc(): """acquire() should not allocate GPU memory — critical for cudagraph.""" config = DSV4Config.flash() schedule = build_schedule(config) mgr = KVCacheManager(config, schedule, max_concurrent_requests=4, num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64) slot = mgr.admit_request() torch.cuda.synchronize() before = torch.cuda.memory_allocated() handle = mgr.acquire( layer_idx=0, request_slots=torch.tensor([slot], dtype=torch.int32, device="cuda"), positions=torch.tensor([0], dtype=torch.int32, device="cuda"), request_ids=torch.tensor([0], dtype=torch.int32, device="cuda"), ) torch.cuda.synchronize() after = torch.cuda.memory_allocated() # Allow small variance from tensor view creation, but no large alloc assert after - before < 1024, f"acquire() allocated {after - before} bytes — breaks cudagraph" assert handle.paged is None # layer 0 is SWA mgr.release_request(slot) def test_manager_memory_tracking(): config = DSV4Config.flash() schedule = build_schedule(config) mgr = KVCacheManager(config, schedule, max_concurrent_requests=4, num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64) total = mgr.memory_bytes() assert total > 0 # Rough sanity: should be in the MB range for this config assert total > 1_000_000 # at least 1 MB assert total < 10_000_000_000 # less than 10 GB def test_full_flash_stack_construction(): """Construct manager with all 43 Flash layers — pools for every layer.""" config = DSV4Config.flash() schedule = build_schedule(config) mgr = KVCacheManager(config, schedule, max_concurrent_requests=4, num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64) # 2 SWA layers (no paged pool) + 41 compressed layers assert len(mgr.paged_pools) == 43 assert mgr.paged_pools[0] is None # SWA assert mgr.paged_pools[1] is None # SWA assert mgr.paged_pools[2] is not None # CSA assert mgr.paged_pools[3] is not None # HCA # All state pools present assert len(mgr.state_pools) == 43 for i in range(43): assert mgr.state_pools[i] is not None if __name__ == "__main__": pytest.main([__file__, "-v"])