Files
nvfp4-megamoe-kernel/tests/archive/test_cache.py
biondizzle 0ced79ab37 Clean up: archive diagnostics and superseded tests
Kept:
- example10 (CUTLASS LLM, O rescale + final normalize)
- example9 (SSA kv_coord version)
- working_softmax_maybe.py (working softmax snapshot from before the nuke)
- test_fmha_v3_stage_c.py (identity softmax baseline, n=128 cos 0.999998)
- test_fmha_v3.py (Stage A+B baseline)
- layertest.py, cudagraph_test.py (required)
- test_cutedsl.py, test_fp4_roundtrip.py (NVFP4 tests)

Archived: diag_tma_*, example8, test_diag_multitile, test_reference_fmha,
test_ref_minimal, test_tma_coord, test_fmha_v3_diag*, test_fmha_v3_12w,
test_dense_router, test_interleave*, test_fused_step1, test_router,
test_cache, test_compile_custom_op, test_custom_op, test_layer_schedule
2026-05-23 00:17:07 +00:00

253 lines
9.3 KiB
Python

"""Tests for KV cache: schema, allocator, pools, manager lifecycle."""
import torch
import pytest
from dsv4.model.config import DSV4Config
from dsv4.model.layer_schedule import build_schedule, AttentionType, RouterMode, LayerSpec
from dsv4.cache.schema import build_schema, compute_block_budget, BLOCK_SIZE_ORIGINAL_TOKENS
from dsv4.cache.allocator import BlockAllocator
from dsv4.cache.paged_cache import PagedKVPool
from dsv4.cache.state_cache import StateCachePool
from dsv4.cache.manager import KVCacheManager
# ---- Schema tests ----
def test_csa_schema():
    """Check the CSA schema built for layer 2 of the Pro config.

    Expects 32 compressed entries per block (128 / 4), a matching number of
    indexer entries, a 3-entry tail buffer (m - 1) and a 128-token SWA window.
    """
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.pro()
    spec = LayerSpec(layer_idx=2, attn=AttentionType.CSA,
                     ffn=FFNType.MOE, router_mode=RouterMode.HASH)
    schema = build_schema(config, spec)
    assert schema.entries_per_block == 32  # 128 / 4
    assert schema.indexer_entries_per_block == 32
    assert schema.tail_buffer_size == 3  # m - 1
    assert schema.swa_window_size == 128
def test_hca_schema():
    """Check the HCA schema built for layer 3 of the Pro config.

    Expects one compressed entry per block (128 / 128), no indexer entries
    and a 127-entry tail buffer (m' - 1).
    """
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.pro()
    spec = LayerSpec(layer_idx=3, attn=AttentionType.HCA,
                     ffn=FFNType.MOE, router_mode=RouterMode.DENSE)
    schema = build_schema(config, spec)
    assert schema.entries_per_block == 1  # 128 / 128
    assert schema.indexer_entries_per_block == 0
    assert schema.tail_buffer_size == 127  # m' - 1
def test_swa_schema():
    """Check the SWA-only schema built for layer 0 of the Flash config.

    An SWA layer has no compressed entries, no indexer entries and no tail
    buffer; only the 128-token sliding window remains.
    """
    from dsv4.model.layer_schedule import FFNType

    config = DSV4Config.flash()
    spec = LayerSpec(layer_idx=0, attn=AttentionType.SWA,
                     ffn=FFNType.MOE, router_mode=RouterMode.HASH)
    schema = build_schema(config, spec)
    assert schema.entries_per_block == 0
    assert schema.indexer_entries_per_block == 0
    assert schema.tail_buffer_size == 0
    assert schema.swa_window_size == 128
def test_schema_from_schedule():
    """Build a schema for every layer of the Flash schedule and sanity-check it.

    Each schema must have a positive SWA window and use the config's head
    dimension. SWA layers must have no compressed entries per block; every
    other layer must have at least one.
    """
    cfg = DSV4Config.flash()
    for layer_spec in build_schedule(cfg):
        layer_schema = build_schema(cfg, layer_spec)
        assert layer_schema.swa_window_size > 0
        assert layer_schema.entry_head_dim == cfg.head_dim
        is_swa_layer = layer_spec.attn == AttentionType.SWA
        if is_swa_layer:
            assert layer_schema.entries_per_block == 0
        else:
            assert layer_schema.entries_per_block > 0
def test_block_budget():
    """compute_block_budget reports both 'csa' and 'hca' budgets, CSA being larger."""
    cfg = DSV4Config.pro()
    budget = compute_block_budget(cfg, build_schedule(cfg), 1_000_000, 16)
    for pool_kind in ("csa", "hca"):
        assert pool_kind in budget
    # CSA uses more blocks per request than HCA.
    assert budget["csa"] > budget["hca"]
# ---- Allocator tests ----
def test_acquire_release_roundtrip():
    """Acquire/release must adjust num_free exactly and restore the full pool.

    The final acquisition is released as well, so the test ends with every
    block free again. The original bound it to a local that was never used.
    """
    alloc = BlockAllocator(num_total_blocks=1024)
    a = alloc.acquire(10)
    assert alloc.num_free == 1014
    b = alloc.acquire(5)
    assert alloc.num_free == 1009
    alloc.release(a)
    assert alloc.num_free == 1019
    alloc.release(b)
    assert alloc.num_free == 1024
    # Re-acquire works after release, and that allocation round-trips too.
    c = alloc.acquire(20)
    assert alloc.num_free == 1004
    alloc.release(c)
    assert alloc.num_free == 1024
def test_oom_raises():
    """Asking a fully drained allocator for one more block raises an OOM error."""
    allocator = BlockAllocator(num_total_blocks=4)
    allocator.acquire(4)  # drain every block
    with pytest.raises(RuntimeError, match="OOM"):
        allocator.acquire(1)
def test_acquire_returns_unique_ids():
    """Two acquisitions that together drain a 100-block pool yield 100 distinct ids.

    Counting unique ids over the concatenation detects overlap between the
    two acquisitions as well as duplicates inside a single one.
    """
    alloc = BlockAllocator(num_total_blocks=100)
    a = alloc.acquire(50)
    b = alloc.acquire(50)
    # torch has no intersect1d (that is NumPy API), so compare the unique
    # count against the total count instead. as_tensor accepts either a
    # tensor or a plain sequence of ints.
    ids = torch.cat([torch.as_tensor(a).flatten(), torch.as_tensor(b).flatten()])
    assert ids.numel() == 100
    assert torch.unique(ids).numel() == ids.numel()
# ---- Pool shape tests ----
def test_paged_pool_shapes_csa():
    """A 16-block CSA paged pool has 32 entries per block plus an FP4 indexer cache."""
    from dsv4.model.layer_schedule import FFNType

    cfg = DSV4Config.pro()
    csa_spec = LayerSpec(
        layer_idx=2,
        attn=AttentionType.CSA,
        ffn=FFNType.MOE,
        router_mode=RouterMode.HASH,
    )
    pool = PagedKVPool(build_schema(cfg, csa_spec), num_blocks=16)
    non_rope_dim = cfg.head_dim - cfg.rope_dim
    assert pool.entries_fp8.shape == (16, 32, non_rope_dim)
    assert pool.entries_rope.shape == (16, 32, cfg.rope_dim)
    assert pool.inv_scale.shape == (16, 32)
    indexer = pool.indexer_keys_fp4
    assert indexer is not None
    assert indexer.shape[1] == 32
def test_paged_pool_shapes_hca():
    """A 256-block HCA paged pool holds one entry per block and no indexer cache."""
    from dsv4.model.layer_schedule import FFNType

    cfg = DSV4Config.pro()
    hca_spec = LayerSpec(
        layer_idx=3,
        attn=AttentionType.HCA,
        ffn=FFNType.MOE,
        router_mode=RouterMode.DENSE,
    )
    pool = PagedKVPool(build_schema(cfg, hca_spec), num_blocks=256)
    non_rope_dim = cfg.head_dim - cfg.rope_dim
    assert pool.entries_fp8.shape == (256, 1, non_rope_dim)
    assert pool.entries_rope.shape == (256, 1, cfg.rope_dim)
    assert pool.indexer_keys_fp4 is None
def test_state_pool_shapes_csa():
    """A CSA state pool for 8 requests has a 128-slot SWA buffer and both tail streams."""
    from dsv4.model.layer_schedule import FFNType

    cfg = DSV4Config.pro()
    csa_spec = LayerSpec(
        layer_idx=2,
        attn=AttentionType.CSA,
        ffn=FFNType.MOE,
        router_mode=RouterMode.HASH,
    )
    pool = StateCachePool(build_schema(cfg, csa_spec), max_requests=8)
    non_rope_dim = cfg.head_dim - cfg.rope_dim
    assert pool.swa_fp8.shape == (8, 128, non_rope_dim)
    assert pool.swa_rope.shape == (8, 128, cfg.rope_dim)
    # CSA has both tail streams (ka and kb) plus the tail length tracker.
    for attr_name in ("tail_ka", "tail_kb", "tail_len"):
        assert getattr(pool, attr_name) is not None
def test_state_pool_shapes_hca():
    """An HCA state pool keeps only one tail stream: tail_kb is absent."""
    from dsv4.model.layer_schedule import FFNType

    cfg = DSV4Config.pro()
    hca_spec = LayerSpec(
        layer_idx=3,
        attn=AttentionType.HCA,
        ffn=FFNType.MOE,
        router_mode=RouterMode.DENSE,
    )
    pool = StateCachePool(build_schema(cfg, hca_spec), max_requests=8)
    # (attribute, expected to exist), checked in this order.
    expectations = (
        ("tail_ka", True),
        ("tail_kb", False),  # HCA only one stream
        ("tail_za", True),
        ("tail_len", True),
    )
    for attr_name, should_exist in expectations:
        assert (getattr(pool, attr_name) is not None) == should_exist
def test_state_pool_shapes_swa():
    """An SWA-only state pool keeps the window buffer but allocates no tail state."""
    from dsv4.model.layer_schedule import FFNType

    cfg = DSV4Config.flash()
    swa_spec = LayerSpec(
        layer_idx=0,
        attn=AttentionType.SWA,
        ffn=FFNType.MOE,
        router_mode=RouterMode.HASH,
    )
    pool = StateCachePool(build_schema(cfg, swa_spec), max_requests=8)
    assert pool.swa_fp8.shape == (8, 128, cfg.head_dim - cfg.rope_dim)
    # No tail for SWA-only layers.
    assert pool.tail_ka is None
    assert pool.tail_len is None
# ---- Manager lifecycle tests ----
def test_admit_release_recycles_slot():
    """A slot freed by release_request() is handed out again by the next admit."""
    cfg = DSV4Config.flash()
    mgr = KVCacheManager(
        cfg,
        build_schedule(cfg),
        max_concurrent_requests=4,
        num_blocks_per_csa_layer=64,
        num_blocks_per_hca_layer=64,
    )
    first_slot = mgr.admit_request()
    mgr.release_request(first_slot)
    second_slot = mgr.admit_request()
    assert second_slot == first_slot  # slot was recycled
def test_admit_exhaustion():
    """Admitting past max_concurrent_requests raises a RuntimeError mentioning 'concurrent'."""
    cfg = DSV4Config.flash()
    capacity = 2
    mgr = KVCacheManager(
        cfg,
        build_schedule(cfg),
        max_concurrent_requests=capacity,
        num_blocks_per_csa_layer=64,
        num_blocks_per_hca_layer=64,
    )
    for _ in range(capacity):
        mgr.admit_request()
    with pytest.raises(RuntimeError, match="concurrent"):
        mgr.admit_request()
def test_handle_construction_no_alloc():
    """acquire() should not allocate GPU memory — critical for cudagraph.

    The input index tensors are created *before* the baseline is sampled.
    PyTorch's CUDA caching allocator rounds even a 4-byte tensor up to a
    512-byte block, so building the three inputs inside the measured window
    could by itself exceed the 1 KiB budget and be blamed on acquire().
    """
    if not torch.cuda.is_available():
        pytest.skip("requires CUDA")
    config = DSV4Config.flash()
    schedule = build_schedule(config)
    mgr = KVCacheManager(config, schedule, max_concurrent_requests=4,
                         num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64)
    slot = mgr.admit_request()
    try:
        request_slots = torch.tensor([slot], dtype=torch.int32, device="cuda")
        positions = torch.tensor([0], dtype=torch.int32, device="cuda")
        request_ids = torch.tensor([0], dtype=torch.int32, device="cuda")
        torch.cuda.synchronize()
        before = torch.cuda.memory_allocated()
        handle = mgr.acquire(
            layer_idx=0,
            request_slots=request_slots,
            positions=positions,
            request_ids=request_ids,
        )
        torch.cuda.synchronize()
        delta = torch.cuda.memory_allocated() - before
        # Allow small variance from tensor view creation, but no large alloc.
        assert delta < 1024, f"acquire() allocated {delta} bytes — breaks cudagraph"
        assert handle.paged is None  # layer 0 is SWA
    finally:
        # Free the slot even when an assertion above fails.
        mgr.release_request(slot)
def test_manager_memory_tracking():
    """memory_bytes() reports a plausible footprint, between 1 MB and 10 GB."""
    cfg = DSV4Config.flash()
    mgr = KVCacheManager(
        cfg,
        build_schedule(cfg),
        max_concurrent_requests=4,
        num_blocks_per_csa_layer=64,
        num_blocks_per_hca_layer=64,
    )
    footprint = mgr.memory_bytes()
    assert footprint > 0
    # Rough sanity bounds for this configuration.
    assert footprint > 1_000_000  # at least 1 MB
    assert footprint < 10_000_000_000  # less than 10 GB
def test_full_flash_stack_construction():
    """Construct manager with all 43 Flash layers — pools for every layer.

    Every layer must have a state pool. A layer must have a paged pool
    exactly when it is not SWA.
    """
    num_layers = 43
    config = DSV4Config.flash()
    schedule = build_schedule(config)
    mgr = KVCacheManager(config, schedule, max_concurrent_requests=4,
                         num_blocks_per_csa_layer=64, num_blocks_per_hca_layer=64)
    # 2 SWA layers (no paged pool) + 41 compressed layers
    assert len(mgr.paged_pools) == num_layers
    assert mgr.paged_pools[0] is None  # SWA
    assert mgr.paged_pools[1] is None  # SWA
    assert mgr.paged_pools[2] is not None  # CSA
    assert mgr.paged_pools[3] is not None  # HCA
    # All state pools present
    assert len(mgr.state_pools) == num_layers
    # Check every layer against its spec rather than only the first four.
    for spec in schedule:
        idx = spec.layer_idx
        paged = mgr.paged_pools[idx]
        if spec.attn == AttentionType.SWA:
            assert paged is None, f"SWA layer {idx} has a paged pool"
        else:
            assert paged is not None, f"compressed layer {idx} has no paged pool"
        assert mgr.state_pools[idx] is not None, f"layer {idx} has no state pool"
if __name__ == "__main__":
    # Propagate pytest's exit code so a failing run exits non-zero.
    raise SystemExit(pytest.main([__file__, "-v"]))