DSV4Config: frozen dataclass with .flash() / .pro() classmethods. All architectural constants (dims, heads, MoE params, mHC) live in one place.

LayerSchedule: pure-data mapping from layer index -> (attn_type, ffn_type, router_mode).
  Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
  Pro:   HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
  Both:  the first 3 MoE layers use hash routing; the rest use dense routing.
  validate_schedule() enforces correctness at construction.

AttentionSubBlock: CSA / HCA / SWA variants.
- Low-rank Q projection (q_down -> q_up)
- KV down-projection (varies by attention type: 4h / 2h / 1h)
- CSA: indexer_q_up + indexer_head_weights
- Grouped output projection (wo_a + wo_b)
- Kernel calls are imports (they raise NotImplementedError until the kernel lands)
- No PyTorch fallback paths

FFNSubBlock: MoE + shared expert.
- Router mode (hash/dense) comes from the LayerSpec
- Nvfp4MoE + Nvfp4SharedExpert

TransformerLayer: composition of mHC + norm + attention + FFN.
- Two mHC wrappers (one each for the attention and FFN sub-blocks)
- Two RMSNorms (one per sub-block)
- Pure orchestration; no learned parameters on the layer itself

Tests: schedule construction + validation for both variants. No forward tests yet (they depend on the FMHA kernel + KV cache).
86 lines
2.5 KiB
Python
86 lines
2.5 KiB
Python
"""Tests for layer schedule — pure data, no kernels, no tensors."""
|
|
|
|
from dsv4.model.config import DSV4Config
|
|
from dsv4.model.layer_schedule import (
|
|
AttentionType, FFNType, RouterMode,
|
|
LayerSpec, build_schedule, validate_schedule,
|
|
)
|
|
|
|
|
|
def test_flash_schedule():
    """Flash variant: 43 layers, SWA/SWA then CSA/HCA alternation, 3 hash-routed layers."""
    config = DSV4Config.flash()
    schedule = build_schedule(config)
    validate_schedule(schedule, config)

    assert len(schedule) == 43

    # First two layers: SWA + hash routing
    assert schedule[0].attn == AttentionType.SWA
    assert schedule[1].attn == AttentionType.SWA
    assert schedule[0].router_mode == RouterMode.HASH
    assert schedule[1].router_mode == RouterMode.HASH

    # Layer 2: CSA + hash routing (last hash layer)
    assert schedule[2].attn == AttentionType.CSA
    assert schedule[2].router_mode == RouterMode.HASH

    # Layer 3: HCA + dense routing (first dense layer)
    assert schedule[3].attn == AttentionType.HCA
    assert schedule[3].router_mode == RouterMode.DENSE

    # Alternation continues for the whole stack, not just the first few
    # layers: from layer 2 onward, even indices are CSA and odd are HCA.
    for i, spec in enumerate(schedule[2:], start=2):
        expected_attn = AttentionType.CSA if i % 2 == 0 else AttentionType.HCA
        assert spec.attn == expected_attn, (
            f"layer {i}: attn {spec.attn} != {expected_attn}"
        )

    # All layers are MoE; only the first 3 use hash routing, the rest dense.
    for i, spec in enumerate(schedule):
        assert spec.ffn == FFNType.MOE, f"layer {i}: ffn {spec.ffn} is not MoE"
        expected_router = RouterMode.HASH if i < 3 else RouterMode.DENSE
        assert spec.router_mode == expected_router, (
            f"layer {i}: router_mode {spec.router_mode} != {expected_router}"
        )
|
|
|
|
|
|
def test_pro_schedule():
    """Pro variant: 61 layers, HCA/HCA then CSA/HCA alternation, 3 hash-routed layers."""
    config = DSV4Config.pro()
    schedule = build_schedule(config)
    validate_schedule(schedule, config)

    assert len(schedule) == 61

    # First two layers: HCA + hash routing
    assert schedule[0].attn == AttentionType.HCA
    assert schedule[1].attn == AttentionType.HCA
    assert schedule[0].router_mode == RouterMode.HASH
    assert schedule[1].router_mode == RouterMode.HASH

    # Layer 2: CSA + hash routing (last hash layer)
    assert schedule[2].attn == AttentionType.CSA
    assert schedule[2].router_mode == RouterMode.HASH

    # Layer 3: HCA + dense routing (first dense layer)
    assert schedule[3].attn == AttentionType.HCA
    assert schedule[3].router_mode == RouterMode.DENSE

    # From layer 2 onward attention alternates: even indices CSA, odd HCA.
    for i, spec in enumerate(schedule[2:], start=2):
        expected_attn = AttentionType.CSA if i % 2 == 0 else AttentionType.HCA
        assert spec.attn == expected_attn, (
            f"layer {i}: attn {spec.attn} != {expected_attn}"
        )

    # All layers are MoE; only the first 3 use hash routing, the rest dense.
    for i, spec in enumerate(schedule):
        assert spec.ffn == FFNType.MOE, f"layer {i}: ffn {spec.ffn} is not MoE"
        expected_router = RouterMode.HASH if i < 3 else RouterMode.DENSE
        assert spec.router_mode == expected_router, (
            f"layer {i}: router_mode {spec.router_mode} != {expected_router}"
        )
|
|
|
|
|
|
def test_layer_spec_frozen():
    """LayerSpec is frozen — mutation should raise and leave the spec unchanged."""
    config = DSV4Config.flash()
    spec = build_schedule(config)[0]
    original_attn = spec.attn

    try:
        spec.attn = AttentionType.HCA
    except AttributeError:
        # Expected path. (A frozen dataclass raises FrozenInstanceError,
        # which is a subclass of AttributeError.)
        pass
    else:
        # Explicit raise instead of `assert False`: assert statements are
        # stripped under `python -O`, which would make this test pass vacuously.
        raise AssertionError("should have raised")

    # The failed assignment must not have changed the stored value.
    assert spec.attn == original_attn
|
|
|
|
|
|
def test_schedule_indices_match():
    """Each LayerSpec.layer_idx matches its position in the list, for both variants."""
    for config in (DSV4Config.flash(), DSV4Config.pro()):
        schedule = build_schedule(config)
        for i, spec in enumerate(schedule):
            assert spec.layer_idx == i, (
                f"spec at position {i} has layer_idx {spec.layer_idx}"
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # Minimal runner so the file can be executed directly without pytest;
    # tests run in declaration order and the first failure aborts the run.
    for _schedule_test in (
        test_flash_schedule,
        test_pro_schedule,
        test_layer_spec_frozen,
        test_schedule_indices_match,
    ):
        _schedule_test()
    print("All schedule tests passed")
|