DSV4Config: frozen dataclass with .flash() / .pro() classmethods.
All architectural constants (dims, heads, MoE params, mHC) in one place.

LayerSchedule: pure-data per-layer-index -> (attn_type, ffn_type, router_mode).
  Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
  Pro:   HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
  Both:  first 3 MoE layers = hash routing, rest = dense
  validate_schedule() enforces correctness at construction.

AttentionSubBlock: CSA / HCA / SWA variants.
  - Low-rank Q projection (q_down -> q_up)
  - KV down-projection (varies by attn type: 4h/2h/1h)
  - CSA: indexer_q_up + indexer_head_weights
  - Grouped output projection (wo_a + wo_b)
  - Kernel calls are imports (NotImplementedError until kernel lands)
  - No PyTorch fallback paths

FFNSubBlock: MoE + shared expert.
  - Router (hash/dense) mode from LayerSpec
  - Nvfp4MoE + Nvfp4SharedExpert

TransformerLayer: composition of mHC + norm + attention + FFN.
  - Two mHC wrappers (attn + ffn sub-blocks)
  - Two RMSNorm (one per sub-block)
  - Pure orchestration, no learned params on the layer itself

Tests: schedule construction + validation for both variants.
No forward tests yet (depends on FMHA kernel + KV cache).
121 lines
3.5 KiB
Python
121 lines
3.5 KiB
Python
"""DSV4 architecture configuration.
|
|
|
|
One config class for both Flash and Pro — they differ only in numeric
parameters, not structure. Construct via .flash() / .pro() classmethods
or load from a checkpoint's config.json.
|
|
"""
|
|
from __future__ import annotations
|
|
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass(frozen=True)
class DSV4Config:
    """Immutable architecture hyperparameters for one DSV4 model variant.

    Flash and Pro share this structure and differ only in numeric values.
    Build a preset with :meth:`flash` or :meth:`pro`, or pass every field
    explicitly. Instances are frozen, so fields cannot be reassigned after
    construction.

    Raises:
        ValueError: if ``variant`` is not ``"flash"`` or ``"pro"``, if a size
            or count is out of range, or if a field exceeds the quantity that
            bounds it (for example ``rope_dim > head_dim``).
    """

    # Identity
    variant: str  # "flash" | "pro"
    vocab_size: int

    # Transformer
    num_layers: int  # Flash: 43, Pro: 61
    hidden_size: int  # Flash: 4096, Pro: 7168

    # Attention (shared across CSA/HCA/SWA)
    num_query_heads: int  # Flash: 64, Pro: 128
    head_dim: int  # Both: 512
    rope_dim: int  # Both: 64 (partial RoPE on last 64 dims)
    query_compression_dim: int  # Flash: 1024, Pro: 1536

    # CSA-specific
    csa_compression_ratio: int  # m, both: 4
    csa_top_k: int  # Flash: 512, Pro: 1024
    indexer_num_heads: int  # Both: 64
    indexer_head_dim: int  # Both: 128

    # HCA-specific
    hca_compression_ratio: int  # m', both: 128

    # SWA branch
    sliding_window: int  # n_win, both: 128

    # Output projection grouping
    num_output_groups: int  # Flash: 8, Pro: 16
    output_group_dim: int  # Both: 1024

    # MoE
    num_routed_experts: int  # Flash: 256, Pro: 384
    num_shared_experts: int  # Both: 1
    num_experts_per_tok: int  # Both: 6
    moe_intermediate_size: int  # Flash: 2048, Pro: 3072
    num_hash_routing_layers: int  # Both: 3
    routed_scaling_factor: float  # Verify per checkpoint

    # mHC
    n_hc: int  # Both: 4
    sinkhorn_iters: int  # Both: 20

    # Architecture schedule — first N layers use a non-standard attention
    # pattern (paper §4.2.1).
    #   Flash: layers 0-1 = SWA only
    #   Pro:   layers 0-1 = HCA only
    num_special_first_layers: int = 2

    # Norm
    rms_norm_eps: float = 1e-6

    def __post_init__(self) -> None:
        """Reject configurations that are inconsistent on their face.

        Failing here gives an error that names the offending field, instead
        of a harder-to-trace failure wherever the value is first consumed.

        Raises:
            ValueError: on an unknown variant, an out-of-range size or count,
                or a field larger than the quantity that bounds it.
        """
        if self.variant not in ("flash", "pro"):
            raise ValueError(
                f"variant must be 'flash' or 'pro', got {self.variant!r}"
            )

        # Sizes and counts that cannot meaningfully be zero.
        must_be_positive = (
            "vocab_size",
            "num_layers",
            "hidden_size",
            "num_query_heads",
            "head_dim",
            "query_compression_dim",
            "csa_compression_ratio",
            "csa_top_k",
            "indexer_num_heads",
            "indexer_head_dim",
            "hca_compression_ratio",
            "sliding_window",
            "num_output_groups",
            "output_group_dim",
            "num_routed_experts",
            "num_experts_per_tok",
            "moe_intermediate_size",
            "n_hc",
        )
        for field_name in must_be_positive:
            value = getattr(self, field_name)
            if value <= 0:
                raise ValueError(f"{field_name} must be positive, got {value!r}")

        # Counts for which zero plausibly means "feature disabled".
        must_be_non_negative = (
            "rope_dim",
            "num_shared_experts",
            "num_hash_routing_layers",
            "sinkhorn_iters",
            "num_special_first_layers",
        )
        for field_name in must_be_non_negative:
            value = getattr(self, field_name)
            if value < 0:
                raise ValueError(
                    f"{field_name} must be non-negative, got {value!r}"
                )

        # Each (field, bound) pair: the first is a subset or count drawn from
        # the second, so it can never exceed it.
        upper_bounds = (
            ("rope_dim", "head_dim"),
            ("num_experts_per_tok", "num_routed_experts"),
            ("num_hash_routing_layers", "num_layers"),
            ("num_special_first_layers", "num_layers"),
        )
        for field_name, bound_name in upper_bounds:
            value = getattr(self, field_name)
            bound = getattr(self, bound_name)
            if value > bound:
                raise ValueError(
                    f"{field_name} ({value!r}) must not exceed "
                    f"{bound_name} ({bound!r})"
                )

    @staticmethod
    def _shared_preset_kwargs() -> dict:
        """Return the constructor arguments that Flash and Pro have in common.

        Kept in one place so the two presets cannot drift apart on values
        that are meant to be identical. A fresh dict is returned on each call.
        """
        return dict(
            vocab_size=128_000,
            head_dim=512,
            rope_dim=64,
            csa_compression_ratio=4,
            indexer_num_heads=64,
            indexer_head_dim=128,
            hca_compression_ratio=128,
            sliding_window=128,
            output_group_dim=1024,
            num_shared_experts=1,
            num_experts_per_tok=6,
            num_hash_routing_layers=3,
            routed_scaling_factor=2.5,
            n_hc=4,
            sinkhorn_iters=20,
        )

    @classmethod
    def flash(cls) -> "DSV4Config":
        """Return the Flash preset (43 layers, hidden size 4096)."""
        return cls(
            variant="flash",
            num_layers=43,
            hidden_size=4096,
            num_query_heads=64,
            query_compression_dim=1024,
            csa_top_k=512,
            num_output_groups=8,
            num_routed_experts=256,
            moe_intermediate_size=2048,
            **cls._shared_preset_kwargs(),
        )

    @classmethod
    def pro(cls) -> "DSV4Config":
        """Return the Pro preset (61 layers, hidden size 7168)."""
        return cls(
            variant="pro",
            num_layers=61,
            hidden_size=7168,
            num_query_heads=128,
            query_compression_dim=1536,
            csa_top_k=1024,
            num_output_groups=16,
            num_routed_experts=384,
            moe_intermediate_size=3072,
            **cls._shared_preset_kwargs(),
        )
|