Files
nvfp4-megamoe-kernel/dsv4/model/config.py

121 lines
3.5 KiB
Python
Raw Normal View History

"""DSV4 architecture configuration.
One config class serves both Flash and Pro: they differ only in numeric
parameters, not in structure. Construct one via the .flash() / .pro()
classmethods, or load it from a checkpoint's config.json.
"""
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class DSV4Config:
    """Immutable architecture hyperparameters for a DSV4 model.

    Flash and Pro share this structure and differ only in the numbers;
    build them with ``DSV4Config.flash()`` / ``DSV4Config.pro()``. The
    dataclass is frozen, so derive a modified copy with
    ``dataclasses.replace(cfg, field=value)`` rather than assigning.

    ``__post_init__`` rejects structurally impossible values (non-positive
    sizes, ``rope_dim > head_dim``, more experts per token than routed
    experts, prefix-layer counts exceeding ``num_layers``, non-positive
    ``rms_norm_eps``) so a bad config fails at construction time instead
    of somewhere downstream.
    """

    # Identity
    variant: str  # "flash" | "pro"
    vocab_size: int
    # Transformer
    num_layers: int  # Flash: 43, Pro: 61
    hidden_size: int  # Flash: 4096, Pro: 7168
    # Attention (shared across CSA/HCA/SWA)
    num_query_heads: int  # Flash: 64, Pro: 128
    head_dim: int  # Both: 512
    rope_dim: int  # Both: 64 (partial RoPE on last 64 dims)
    query_compression_dim: int  # Flash: 1024, Pro: 1536
    # CSA-specific
    csa_compression_ratio: int  # m, both: 4
    csa_top_k: int  # Flash: 512, Pro: 1024
    indexer_num_heads: int  # Both: 64
    indexer_head_dim: int  # Both: 128
    # HCA-specific
    hca_compression_ratio: int  # m', both: 128
    # SWA branch
    sliding_window: int  # n_win, both: 128
    # Output projection grouping
    num_output_groups: int  # Flash: 8, Pro: 16
    output_group_dim: int  # Both: 1024
    # MoE
    num_routed_experts: int  # Flash: 256, Pro: 384
    num_shared_experts: int  # Both: 1
    num_experts_per_tok: int  # Both: 6
    moe_intermediate_size: int  # Flash: 2048, Pro: 3072
    num_hash_routing_layers: int  # Both: 3
    routed_scaling_factor: float  # Verify per checkpoint
    # mHC
    n_hc: int  # Both: 4
    sinkhorn_iters: int  # Both: 20
    # Architecture schedule — first N layers use a non-standard attention
    # pattern (paper §4.2.1).
    #   Flash: layers 0-1 = SWA only
    #   Pro:   layers 0-1 = HCA only
    num_special_first_layers: int = 2
    # Norm
    rms_norm_eps: float = 1e-6

    def __post_init__(self) -> None:
        """Validate structural invariants of the configuration.

        Raises:
            ValueError: if any size/count/ratio is out of range, if
                ``rope_dim`` does not fit inside ``head_dim``, if more
                experts are selected per token than exist, or if a
                prefix-layer count exceeds ``num_layers``.
        """
        # Dimensions, head/expert counts, ratios and windows: a value of
        # zero or below has no meaningful interpretation for these fields.
        positive_fields = (
            "vocab_size",
            "num_layers",
            "hidden_size",
            "num_query_heads",
            "head_dim",
            "query_compression_dim",
            "csa_compression_ratio",
            "csa_top_k",
            "indexer_num_heads",
            "indexer_head_dim",
            "hca_compression_ratio",
            "sliding_window",
            "num_output_groups",
            "output_group_dim",
            "num_routed_experts",
            "num_experts_per_tok",
            "moe_intermediate_size",
            "n_hc",
        )
        for name in positive_fields:
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"DSV4Config.{name} must be > 0, got {value!r}")

        # Counts where zero is a legitimate "none" value.
        for name in ("num_shared_experts", "sinkhorn_iters"):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"DSV4Config.{name} must be >= 0, got {value!r}")

        # RoPE is applied to a sub-slice of each head, so it cannot be wider
        # than the head itself.
        if not 0 <= self.rope_dim <= self.head_dim:
            raise ValueError(
                f"DSV4Config.rope_dim must be in [0, head_dim={self.head_dim}], "
                f"got {self.rope_dim!r}"
            )

        # Top-k routing cannot select more experts than exist.
        if self.num_experts_per_tok > self.num_routed_experts:
            raise ValueError(
                f"DSV4Config.num_experts_per_tok ({self.num_experts_per_tok}) "
                f"exceeds num_routed_experts ({self.num_routed_experts})"
            )

        # Both of these describe a prefix of the layer stack.
        for name in ("num_hash_routing_layers", "num_special_first_layers"):
            value = getattr(self, name)
            if not 0 <= value <= self.num_layers:
                raise ValueError(
                    f"DSV4Config.{name} must be in [0, num_layers={self.num_layers}], "
                    f"got {value!r}"
                )

        # Written as `not > 0` so NaN is rejected as well.
        if not self.rms_norm_eps > 0:
            raise ValueError(
                f"DSV4Config.rms_norm_eps must be > 0, got {self.rms_norm_eps!r}"
            )

    @classmethod
    def flash(cls) -> DSV4Config:
        """Return the preset configuration for the Flash variant."""
        return cls(
            variant="flash",
            vocab_size=128_000,
            num_layers=43,
            hidden_size=4096,
            num_query_heads=64,
            head_dim=512,
            rope_dim=64,
            query_compression_dim=1024,
            csa_compression_ratio=4,
            csa_top_k=512,
            indexer_num_heads=64,
            indexer_head_dim=128,
            hca_compression_ratio=128,
            sliding_window=128,
            num_output_groups=8,
            output_group_dim=1024,
            num_routed_experts=256,
            num_shared_experts=1,
            num_experts_per_tok=6,
            moe_intermediate_size=2048,
            num_hash_routing_layers=3,
            routed_scaling_factor=2.5,
            n_hc=4,
            sinkhorn_iters=20,
        )

    @classmethod
    def pro(cls) -> DSV4Config:
        """Return the preset configuration for the Pro variant."""
        return cls(
            variant="pro",
            vocab_size=128_000,
            num_layers=61,
            hidden_size=7168,
            num_query_heads=128,
            head_dim=512,
            rope_dim=64,
            query_compression_dim=1536,
            csa_compression_ratio=4,
            csa_top_k=1024,
            indexer_num_heads=64,
            indexer_head_dim=128,
            hca_compression_ratio=128,
            sliding_window=128,
            num_output_groups=16,
            output_group_dim=1024,
            num_routed_experts=384,
            num_shared_experts=1,
            num_experts_per_tok=6,
            moe_intermediate_size=3072,
            num_hash_routing_layers=3,
            routed_scaling_factor=2.5,
            n_hc=4,
            sinkhorn_iters=20,
        )