"""Tests for layer schedule — pure data, no kernels, no tensors.""" from dsv4.model.config import DSV4Config from dsv4.model.layer_schedule import ( AttentionType, FFNType, RouterMode, LayerSpec, build_schedule, validate_schedule, ) def test_flash_schedule(): config = DSV4Config.flash() schedule = build_schedule(config) validate_schedule(schedule, config) assert len(schedule) == 43 # First two layers: SWA + hash routing assert schedule[0].attn == AttentionType.SWA assert schedule[1].attn == AttentionType.SWA assert schedule[0].router_mode == RouterMode.HASH assert schedule[1].router_mode == RouterMode.HASH # Layer 2: CSA + hash routing (last hash layer) assert schedule[2].attn == AttentionType.CSA assert schedule[2].router_mode == RouterMode.HASH # Layer 3: HCA + dense routing (first dense layer) assert schedule[3].attn == AttentionType.HCA assert schedule[3].router_mode == RouterMode.DENSE # Alternation continues assert schedule[4].attn == AttentionType.CSA assert schedule[5].attn == AttentionType.HCA # All layers are MoE for spec in schedule: assert spec.ffn == FFNType.MOE def test_pro_schedule(): config = DSV4Config.pro() schedule = build_schedule(config) validate_schedule(schedule, config) assert len(schedule) == 61 # First two layers: HCA + hash routing assert schedule[0].attn == AttentionType.HCA assert schedule[1].attn == AttentionType.HCA assert schedule[0].router_mode == RouterMode.HASH # Layer 2: CSA + hash routing assert schedule[2].attn == AttentionType.CSA assert schedule[2].router_mode == RouterMode.HASH # Layer 3: HCA + dense routing assert schedule[3].attn == AttentionType.HCA assert schedule[3].router_mode == RouterMode.DENSE def test_layer_spec_frozen(): """LayerSpec is frozen — mutation should raise.""" config = DSV4Config.flash() spec = build_schedule(config)[0] try: spec.attn = AttentionType.HCA assert False, "should have raised" except AttributeError: pass def test_schedule_indices_match(): """Each LayerSpec.layer_idx matches its position in the list.""" config = DSV4Config.flash() schedule = build_schedule(config) for i, spec in enumerate(schedule): assert spec.layer_idx == i if __name__ == "__main__": test_flash_schedule() test_pro_schedule() test_layer_spec_frozen() test_schedule_indices_match() print("All schedule tests passed")