253 lines
9.3 KiB
Python
253 lines
9.3 KiB
Python
|
|
"""Tests for KV cache: schema, allocator, pools, manager lifecycle."""
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from dsv4.model.config import DSV4Config
|
||
|
|
from dsv4.model.layer_schedule import build_schedule, AttentionType, RouterMode, LayerSpec
|
||
|
|
from dsv4.cache.schema import build_schema, compute_block_budget, BLOCK_SIZE_ORIGINAL_TOKENS
|
||
|
|
from dsv4.cache.allocator import BlockAllocator
|
||
|
|
from dsv4.cache.paged_cache import PagedKVPool
|
||
|
|
from dsv4.cache.state_cache import StateCachePool
|
||
|
|
from dsv4.cache.manager import KVCacheManager
|
||
|
|
|
||
|
|
|
||
|
|
# ---- Schema tests ----
|
||
|
|
|
||
|
|
def test_csa_schema():
    """CSA layer schema for the Pro config.

    The expected values assume 128 original tokens per block and a
    compression ratio m = 4 (see the inline arithmetic):
    - 32 compressed entries per block;
    - 32 indexer entries per block;
    - a tail buffer of m - 1 = 3 pending tokens;
    - a SWA window of size 128.
    """
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.pro()
    spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA,
                     ffn=FFNType.MOE, router_mode=RouterMode.HASH)
    schema = build_schema(config, spec)
    assert schema.entries_per_block == 32  # 128 / 4
    assert schema.indexer_entries_per_block == 32
    assert schema.tail_buffer_size == 3  # m - 1
    assert schema.swa_window_size == 128
|
||
|
|
|
||
|
|
|
||
|
|
def test_hca_schema():
    """HCA layer schema for the Pro config.

    The expected values assume a compression ratio m' = 128, equal to the
    block size in original tokens:
    - a single compressed entry per block;
    - no indexer entries;
    - a tail buffer of m' - 1 = 127 pending tokens.
    """
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.pro()
    spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA,
                     ffn=FFNType.MOE, router_mode=RouterMode.DENSE)
    schema = build_schema(config, spec)
    assert schema.entries_per_block == 1  # 128 / 128
    assert schema.indexer_entries_per_block == 0
    assert schema.tail_buffer_size == 127  # m' - 1
|
||
|
|
|
||
|
|
|
||
|
|
def test_swa_schema():
    """SWA-only layer schema for the Flash config.

    Sliding-window-only layers keep no compressed entries, no indexer
    entries and no tail buffer. Only the SWA window (size 128) remains.
    """
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.flash()
    spec = LayerSpec(layer_idx=0, attn=AttentionType.SWA,
                     ffn=FFNType.MOE, router_mode=RouterMode.HASH)
    schema = build_schema(config, spec)
    assert schema.entries_per_block == 0
    assert schema.indexer_entries_per_block == 0
    assert schema.tail_buffer_size == 0
    assert schema.swa_window_size == 128
|
||
|
|
|
||
|
|
|
||
|
|
def test_schema_from_schedule():
    """Every layer in a full schedule produces a valid schema.

    Every schema in the Flash schedule must have a positive SWA window and
    use the model's head_dim. SWA layers hold no compressed entries, and all
    other layers hold at least one per block.
    """
    cfg = DSV4Config.flash()
    for layer_spec in build_schedule(cfg):
        layer_schema = build_schema(cfg, layer_spec)
        assert layer_schema.swa_window_size > 0
        assert layer_schema.entry_head_dim == cfg.head_dim
        is_swa_layer = layer_spec.attn == AttentionType.SWA
        if is_swa_layer:
            assert layer_schema.entries_per_block == 0
        else:
            assert layer_schema.entries_per_block > 0
|
||
|
|
|
||
|
|
|
||
|
|
def test_block_budget():
    """The Pro block budget has CSA and HCA entries, with CSA the larger."""
    cfg = DSV4Config.pro()
    layer_schedule = build_schedule(cfg)
    # NOTE(review): the positional 1_000_000 and 16 are passed straight to
    # compute_block_budget; their meaning is defined there — confirm.
    budget = compute_block_budget(cfg, layer_schedule, 1_000_000, 16)
    for attn_kind in ("csa", "hca"):
        assert attn_kind in budget
    # CSA uses more blocks per request than HCA.
    assert budget["csa"] > budget["hca"]
|
||
|
|
|
||
|
|
|
||
|
|
# ---- Allocator tests ----
|
||
|
|
|
||
|
|
def test_acquire_release_roundtrip():
    """Acquired blocks leave the free pool and return to it on release.

    Also checks that the allocator is fully reusable once everything has
    been released, including releasing that later allocation again.
    """
    alloc = BlockAllocator(num_total_blocks=1024)
    a = alloc.acquire(10)
    assert alloc.num_free == 1014
    b = alloc.acquire(5)
    assert alloc.num_free == 1009
    # Release in acquisition order (not LIFO).
    alloc.release(a)
    assert alloc.num_free == 1019
    alloc.release(b)
    assert alloc.num_free == 1024
    # Re-acquire works after release, and the new allocation can itself be
    # returned, leaving the pool full again.
    c = alloc.acquire(20)
    assert alloc.num_free == 1004
    alloc.release(c)
    assert alloc.num_free == 1024
|
||
|
|
|
||
|
|
|
||
|
|
def test_oom_raises():
    """Asking an exhausted allocator for a block raises an 'OOM' RuntimeError."""
    allocator = BlockAllocator(num_total_blocks=4)
    # Drain the entire pool so nothing is left to hand out.
    allocator.acquire(4)
    with pytest.raises(RuntimeError, match="OOM"):
        allocator.acquire(1)
|
||
|
|
|
||
|
|
|
||
|
|
def test_acquire_returns_unique_ids():
    """Block ids handed out by separate acquisitions never collide.

    Two acquisitions together drain the 100-block pool. The combined ids
    must all be distinct, which covers duplicates both across and within
    the two batches.
    """
    alloc = BlockAllocator(num_total_blocks=100)
    a = alloc.acquire(50)
    b = alloc.acquire(50)
    # PyTorch has no ``intersect1d`` (that is a NumPy function), so count
    # distinct ids instead. as_tensor keeps the allocator's device/dtype.
    ids = torch.cat([torch.as_tensor(a), torch.as_tensor(b)])
    assert ids.numel() == 100
    assert torch.unique(ids).numel() == ids.numel()
|
||
|
|
|
||
|
|
|
||
|
|
# ---- Pool shape tests ----
|
||
|
|
|
||
|
|
def test_paged_pool_shapes_csa():
    """A CSA paged pool stores 32 entries per block, plus indexer keys.

    Each entry is split into a (head_dim - rope_dim) FP8 part and a
    rope_dim part, and there is a per-entry inverse scale.
    """
    from dsv4.model.layer_schedule import FFNType

    cfg = DSV4Config.pro()
    csa_spec = LayerSpec(
        layer_idx=2,
        attn=AttentionType.CSA,
        ffn=FFNType.MOE,
        router_mode=RouterMode.HASH,
    )
    n_blocks = 16
    pool = PagedKVPool(build_schema(cfg, csa_spec), num_blocks=n_blocks)
    non_rope_dim = cfg.head_dim - cfg.rope_dim
    assert pool.entries_fp8.shape == (n_blocks, 32, non_rope_dim)
    assert pool.entries_rope.shape == (n_blocks, 32, cfg.rope_dim)
    assert pool.inv_scale.shape == (n_blocks, 32)
    indexer = pool.indexer_keys_fp4
    assert indexer is not None
    assert indexer.shape[1] == 32
|
||
|
|
|
||
|
|
|
||
|
|
def test_paged_pool_shapes_hca():
    """An HCA paged pool stores one entry per block and no indexer keys."""
    from dsv4.model.layer_schedule import FFNType

    cfg = DSV4Config.pro()
    hca_spec = LayerSpec(
        layer_idx=3,
        attn=AttentionType.HCA,
        ffn=FFNType.MOE,
        router_mode=RouterMode.DENSE,
    )
    n_blocks = 256
    pool = PagedKVPool(build_schema(cfg, hca_spec), num_blocks=n_blocks)
    non_rope_dim = cfg.head_dim - cfg.rope_dim
    assert pool.entries_fp8.shape == (n_blocks, 1, non_rope_dim)
    assert pool.entries_rope.shape == (n_blocks, 1, cfg.rope_dim)
    # HCA schemas have zero indexer entries, so no FP4 indexer keys exist.
    assert pool.indexer_keys_fp4 is None
|
||
|
|
|
||
|
|
|
||
|
|
def test_state_pool_shapes_csa():
    """A CSA state pool has a 128-wide SWA window per request slot and both tail streams."""
    from dsv4.model.layer_schedule import FFNType

    cfg = DSV4Config.pro()
    csa_schema = build_schema(
        cfg,
        LayerSpec(layer_idx=2, attn=AttentionType.CSA,
                  ffn=FFNType.MOE, router_mode=RouterMode.HASH),
    )
    max_reqs = 8
    state = StateCachePool(csa_schema, max_requests=max_reqs)
    assert state.swa_fp8.shape == (max_reqs, 128, cfg.head_dim - cfg.rope_dim)
    assert state.swa_rope.shape == (max_reqs, 128, cfg.rope_dim)
    # CSA has both streams (tail_ka and tail_kb) as well as tail_len.
    for tail_buffer in (state.tail_ka, state.tail_kb, state.tail_len):
        assert tail_buffer is not None
|
||
|
|
|
||
|
|
|
||
|
|
def test_state_pool_shapes_hca():
    """An HCA state pool keeps a single tail stream (ka) plus tail_za and tail_len."""
    from dsv4.model.layer_schedule import FFNType

    cfg = DSV4Config.pro()
    hca_schema = build_schema(
        cfg,
        LayerSpec(layer_idx=3, attn=AttentionType.HCA,
                  ffn=FFNType.MOE, router_mode=RouterMode.DENSE),
    )
    state = StateCachePool(hca_schema, max_requests=8)
    assert state.tail_ka is not None
    # HCA only has one stream, so the second ("kb") tail buffer is absent.
    assert state.tail_kb is None
    assert all(buf is not None for buf in (state.tail_za, state.tail_len))
|
||
|
|
|
||
|
|
|
||
|
|
def test_state_pool_shapes_swa():
    """An SWA-only state pool has the SWA window but no tail buffers."""
    from dsv4.model.layer_schedule import FFNType

    cfg = DSV4Config.flash()
    swa_schema = build_schema(
        cfg,
        LayerSpec(layer_idx=0, attn=AttentionType.SWA,
                  ffn=FFNType.MOE, router_mode=RouterMode.HASH),
    )
    max_reqs = 8
    state = StateCachePool(swa_schema, max_requests=max_reqs)
    assert state.swa_fp8.shape == (max_reqs, 128, cfg.head_dim - cfg.rope_dim)
    # No tail for SWA-only layers: the schema's tail_buffer_size is 0.
    assert state.tail_ka is None
    assert state.tail_len is None
|
||
|
|
|
||
|
|
|
||
|
|
# ---- Manager lifecycle tests ----
|
||
|
|
|
||
|
|
def test_admit_release_recycles_slot():
    """A slot freed by release_request() is handed out again by the next admit."""
    cfg = DSV4Config.flash()
    manager = KVCacheManager(
        cfg,
        build_schedule(cfg),
        max_concurrent_requests=4,
        num_blocks_per_csa_layer=64,
        num_blocks_per_hca_layer=64,
    )
    first_slot = manager.admit_request()
    manager.release_request(first_slot)
    second_slot = manager.admit_request()
    assert first_slot == second_slot  # slot was recycled
|
||
|
|
|
||
|
|
|
||
|
|
def test_admit_exhaustion():
    """Admitting past max_concurrent_requests raises a RuntimeError.

    The error message must mention 'concurrent'.
    """
    cfg = DSV4Config.flash()
    capacity = 2
    manager = KVCacheManager(
        cfg,
        build_schedule(cfg),
        max_concurrent_requests=capacity,
        num_blocks_per_csa_layer=64,
        num_blocks_per_hca_layer=64,
    )
    # Fill every available request slot.
    for _ in range(capacity):
        manager.admit_request()
    with pytest.raises(RuntimeError, match="concurrent"):
        manager.admit_request()
|
||
|
|
|
||
|
|
|
||
|
|
def test_handle_construction_no_alloc():
    """acquire() should not allocate GPU memory — critical for cudagraph.

    The input index tensors are built *before* the baseline measurement, so
    only memory allocated by ``mgr.acquire`` itself is counted. The CUDA
    caching allocator rounds small tensors up to a minimum block size, so
    measuring the three inputs could by itself exceed the 1 KiB tolerance.
    The test is skipped on machines without CUDA.
    """
    if not torch.cuda.is_available():
        pytest.skip("requires a CUDA device")
    config = DSV4Config.flash()
    schedule = build_schedule(config)
    mgr = KVCacheManager(config, schedule, max_concurrent_requests=4,
                         num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64)
    slot = mgr.admit_request()
    try:
        request_slots = torch.tensor([slot], dtype=torch.int32, device="cuda")
        positions = torch.tensor([0], dtype=torch.int32, device="cuda")
        request_ids = torch.tensor([0], dtype=torch.int32, device="cuda")

        torch.cuda.synchronize()
        before = torch.cuda.memory_allocated()
        handle = mgr.acquire(
            layer_idx=0,
            request_slots=request_slots,
            positions=positions,
            request_ids=request_ids,
        )
        torch.cuda.synchronize()
        after = torch.cuda.memory_allocated()

        # Allow small variance from tensor view creation, but no large alloc
        assert after - before < 1024, f"acquire() allocated {after - before} bytes — breaks cudagraph"
        assert handle.paged is None  # layer 0 is SWA
    finally:
        # Free the slot even when an assertion above fails.
        mgr.release_request(slot)
|
||
|
|
|
||
|
|
|
||
|
|
def test_manager_memory_tracking():
    """memory_bytes() reports a positive, plausible footprint (1 MB to 10 GB)."""
    cfg = DSV4Config.flash()
    manager = KVCacheManager(
        cfg,
        build_schedule(cfg),
        max_concurrent_requests=4,
        num_blocks_per_csa_layer=64,
        num_blocks_per_hca_layer=64,
    )
    footprint = manager.memory_bytes()
    assert footprint > 0
    # Rough sanity: at least 1 MB and less than 10 GB for this config.
    assert 1_000_000 < footprint < 10_000_000_000
|
||
|
|
|
||
|
|
|
||
|
|
def test_full_flash_stack_construction():
    """Construct manager with all 43 Flash layers — pools for every layer.

    Every layer gets a state pool. A layer's paged pool is None exactly
    when the layer is SWA.
    """
    config = DSV4Config.flash()
    schedule = build_schedule(config)
    mgr = KVCacheManager(config, schedule, max_concurrent_requests=4,
                         num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64)
    num_layers = 43

    # 2 SWA layers (no paged pool) + 41 compressed layers
    assert len(mgr.paged_pools) == num_layers
    assert mgr.paged_pools[0] is None  # SWA
    assert mgr.paged_pools[1] is None  # SWA
    assert mgr.paged_pools[2] is not None  # CSA
    assert mgr.paged_pools[3] is not None  # HCA

    # Check every layer, not just the first four. Rebuild the specs as a
    # list in case build_schedule returns a one-shot iterable that the
    # manager already consumed.
    specs = list(build_schedule(config))
    assert len(specs) == num_layers
    for spec, paged in zip(specs, mgr.paged_pools):
        is_swa = spec.attn == AttentionType.SWA
        assert (paged is None) == is_swa, (
            f"layer {spec.layer_idx}: paged pool presence does not match attn={spec.attn}"
        )

    # All state pools present
    assert len(mgr.state_pools) == num_layers
    missing = [i for i, pool in enumerate(mgr.state_pools) if pool is None]
    assert not missing, f"layers without a state pool: {missing}"
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so a script run fails (non-zero exit)
    # when any test fails; previously the return value was discarded.
    raise SystemExit(pytest.main([__file__, "-v"]))
|