DSV4Config: frozen dataclass with .flash() / .pro() classmethods.
All architectural constants (dims, heads, MoE params, mHC) in one place.

LayerSchedule: pure-data per-layer-index -> (attn_type, ffn_type, router_mode).
  Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
  Pro:   HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
  Both:  first 3 MoE layers = hash routing, rest = dense
validate_schedule() enforces correctness at construction.

AttentionSubBlock: CSA / HCA / SWA variants.
- Low-rank Q projection (q_down -> q_up)
- KV down-projection (varies by attn type: 4h/2h/1h)
- CSA: indexer_q_up + indexer_head_weights
- Grouped output projection (wo_a + wo_b)
- Kernel calls are imports (NotImplementedError until kernel lands)
- No PyTorch fallback paths

FFNSubBlock: MoE + shared expert.
- Router (hash/dense) mode from LayerSpec
- Nvfp4MoE + Nvfp4SharedExpert

TransformerLayer: composition of mHC + norm + attention + FFN.
- Two mHC wrappers (attn + ffn sub-blocks)
- Two RMSNorm (one per sub-block)
- Pure orchestration, no learned params on the layer itself

Tests: schedule construction + validation for both variants.
No forward tests yet (depends on FMHA kernel + KV cache).
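
Illustrative sketch of the schedule layout described above (not the code in this change). Assumptions: every layer is MoE, the CSA/HCA alternation continues to the last layer, and "first 3 MoE layers" means layer indices 0-2. AttnType, RouterMode, build_schedule, the LayerSpec fields, and the validate_schedule signature are hypothetical names chosen for the example.

```python
from dataclasses import dataclass
from enum import Enum


class AttnType(Enum):
    SWA = "swa"
    CSA = "csa"
    HCA = "hca"


class RouterMode(Enum):
    HASH = "hash"
    DENSE = "dense"


@dataclass(frozen=True)
class LayerSpec:
    # Hypothetical field layout mirroring (attn_type, ffn_type, router_mode).
    attn_type: AttnType
    ffn_type: str
    router_mode: RouterMode


def build_schedule(prefix, num_layers, num_hash_layers=3):
    """Build a pure-data schedule: a fixed prefix, then alternating CSA/HCA."""
    specs = []
    for i in range(num_layers):
        if i < len(prefix):
            attn = prefix[i]
        else:
            # Assumption: alternation starts with CSA right after the prefix.
            attn = AttnType.CSA if (i - len(prefix)) % 2 == 0 else AttnType.HCA
        # Assumption: all layers are MoE, so the first MoE layers are indices 0..2.
        router = RouterMode.HASH if i < num_hash_layers else RouterMode.DENSE
        specs.append(LayerSpec(attn, "moe", router))
    return tuple(specs)


def validate_schedule(schedule, expected_layers, num_hash_layers=3):
    """Raise ValueError if the schedule breaks the length or routing rules."""
    if len(schedule) != expected_layers:
        raise ValueError(f"expected {expected_layers} layers, got {len(schedule)}")
    for i, spec in enumerate(schedule):
        expected = RouterMode.HASH if i < num_hash_layers else RouterMode.DENSE
        if spec.router_mode is not expected:
            raise ValueError(f"layer {i}: router_mode should be {expected.name}")


# Flash: SWA, SWA, then CSA/HCA alternating (43 layers).
flash = build_schedule((AttnType.SWA, AttnType.SWA), 43)
# Pro: HCA, HCA, then CSA/HCA alternating (61 layers).
pro = build_schedule((AttnType.HCA, AttnType.HCA), 61)

validate_schedule(flash, 43)
validate_schedule(pro, 61)
```

Keeping the schedule as an immutable tuple of frozen specs means validation can run once at construction and the layers only read from it.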