nvfp4-megamoe-kernel/dsv4/model/config.py
biondizzle 66a89859ed Layer dispatch: config, schedule, attention/FFN sub-blocks, TransformerLayer
DSV4Config: frozen dataclass with .flash() / .pro() classmethods.
All architectural constants (dims, heads, MoE params, mHC) in one place.

LayerSchedule: pure-data mapping of layer index -> (attn_type, ffn_type, router_mode).
  Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
  Pro:   HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
  Both:  first 3 MoE layers = hash routing, rest = dense
  validate_schedule() enforces correctness at construction.
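A minimal sketch of how a schedule like this could be built and checked. It assumes every layer carries an MoE FFN (so "first 3 MoE layers" means layers 0-2) and that CSA and HCA strictly alternate after the special prefix. `build_schedule` and the specific checks in this `validate_schedule` are hypothetical, not the repo's code.

```python
from dataclasses import dataclass
from typing import List

ATTN_TYPES = {"swa", "csa", "hca"}


@dataclass(frozen=True)
class LayerSpec:
    attn_type: str    # "swa" | "csa" | "hca"
    ffn_type: str     # assumed "moe" for every layer
    router_mode: str  # "hash" | "dense"


def validate_schedule(specs: List[LayerSpec], num_layers: int,
                      num_hash_routing_layers: int) -> None:
    # Illustrative checks only; the repo's validate_schedule may check more.
    if len(specs) != num_layers:
        raise ValueError(f"expected {num_layers} layers, got {len(specs)}")
    for i, spec in enumerate(specs):
        if spec.attn_type not in ATTN_TYPES:
            raise ValueError(f"layer {i}: unknown attn_type {spec.attn_type!r}")
        expected = "hash" if i < num_hash_routing_layers else "dense"
        if spec.router_mode != expected:
            raise ValueError(f"layer {i}: router_mode should be {expected!r}")


def build_schedule(variant: str, num_layers: int,
                   num_special_first_layers: int = 2,
                   num_hash_routing_layers: int = 3) -> List[LayerSpec]:
    # Special prefix: SWA-only for Flash, HCA-only for Pro.
    first = {"flash": "swa", "pro": "hca"}[variant]
    specs: List[LayerSpec] = []
    for i in range(num_layers):
        if i < num_special_first_layers:
            attn = first
        else:
            # After the prefix: CSA, HCA, CSA, HCA, ...
            offset = i - num_special_first_layers
            attn = "csa" if offset % 2 == 0 else "hca"
        router = "hash" if i < num_hash_routing_layers else "dense"
        specs.append(LayerSpec(attn, "moe", router))
    validate_schedule(specs, num_layers, num_hash_routing_layers)
    return specs
```

Keeping the schedule as plain frozen data means it can be validated and unit-tested without instantiating any weights or kernels.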

AttentionSubBlock: CSA / HCA / SWA variants.
  - Low-rank Q projection (q_down -> q_up)
  - KV down-projection (varies by attn type: 4h/2h/1h)
  - CSA: indexer_q_up + indexer_head_weights
  - Grouped output projection (wo_a + wo_b)
  - Kernel calls are imported stubs (raise NotImplementedError until the kernel lands)
  - No PyTorch fallback paths
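A shape-only sketch of the Q and output paths, assuming q_down maps hidden_size -> query_compression_dim, q_up expands to num_query_heads * head_dim, wo_a projects each group of heads to output_group_dim, and wo_b maps the concatenated groups back to hidden_size. `AttnDims`, `attention_weight_shapes` and `param_count` are hypothetical helpers; KV, indexer and RoPE weights are left out.

```python
import math
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class AttnDims:
    # Hypothetical subset of DSV4Config fields used by the Q/output paths.
    hidden_size: int
    num_query_heads: int
    head_dim: int
    query_compression_dim: int
    num_output_groups: int
    output_group_dim: int


def attention_weight_shapes(d: AttnDims) -> Dict[str, Tuple[int, ...]]:
    if d.num_query_heads % d.num_output_groups != 0:
        raise ValueError("num_query_heads must divide evenly into num_output_groups")
    heads_per_group = d.num_query_heads // d.num_output_groups
    return {
        # Low-rank query path: hidden -> compressed -> all heads.
        "q_down": (d.hidden_size, d.query_compression_dim),
        "q_up": (d.query_compression_dim, d.num_query_heads * d.head_dim),
        # wo_a: one (heads_per_group * head_dim -> output_group_dim) map per group.
        "wo_a": (d.num_output_groups, heads_per_group * d.head_dim, d.output_group_dim),
        # wo_b: concatenated group outputs -> hidden.
        "wo_b": (d.num_output_groups * d.output_group_dim, d.hidden_size),
    }


def param_count(shapes: Dict[str, Tuple[int, ...]]) -> int:
    # Total number of weights across all listed matrices.
    return sum(math.prod(s) for s in shapes.values())
```

Under these assumptions Flash gets q_up = (1024, 32768), wo_a = (8, 4096, 1024) and wo_b = (8192, 4096), or 104,857,600 weights across the four matrices. The two-stage output projection needs about 67M weights where a single 32768 x 4096 matrix would need about 134M.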

FFNSubBlock: MoE + shared expert.
  - Router mode (hash/dense) taken from LayerSpec
  - Nvfp4MoE + Nvfp4SharedExpert
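A hedged sketch of the two router modes. Hash routing here maps a token id to k distinct experts through SHA-256 and gives them equal weights. Dense routing follows a DeepSeek-V3-style recipe: top-k over sigmoid scores, renormalized, then multiplied by routed_scaling_factor. The hash function, the equal weights and the sigmoid scoring are all assumptions, and `hash_route` / `dense_route` are hypothetical names.

```python
import hashlib
import math
from typing import List, Sequence, Tuple


def hash_route(token_id: int, num_experts: int, k: int) -> List[Tuple[int, float]]:
    """Pick k distinct experts from a hash of the token id (no learned params)."""
    if not 0 < k <= num_experts:
        raise ValueError("need 0 < k <= num_experts")
    chosen: List[int] = []
    salt = 0
    while len(chosen) < k:
        digest = hashlib.sha256(f"{token_id}:{salt}".encode()).digest()
        expert = int.from_bytes(digest[:8], "little") % num_experts
        if expert not in chosen:
            chosen.append(expert)
        salt += 1
    # Assumption: each selected expert gets the same weight.
    return [(e, 1.0 / k) for e in chosen]


def dense_route(logits: Sequence[float], k: int,
                routed_scaling_factor: float) -> List[Tuple[int, float]]:
    """Top-k over sigmoid scores, renormalized, then scaled (assumed V3-style)."""
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    total = sum(scores[i] for i in top)
    return [(i, routed_scaling_factor * scores[i] / total) for i in top]
```

The same token id always hits the same experts under hash routing, which is what lets the first layers route without a learned gate.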

TransformerLayer: composition of mHC + norm + attention + FFN.
  - Two mHC wrappers (attn + ffn sub-blocks)
  - Two RMSNorm (one per sub-block)
  - Pure orchestration, no learned params on the layer itself
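A pure-Python sketch of that orchestration. It assumes the mHC (manifold-constrained hyper-connections) update x' = H_res x + H_post^T F(H_pre x) over n_hc residual streams, with H_res pushed toward a doubly stochastic matrix by sinkhorn_iters Sinkhorn-Knopp steps. For brevity the coefficients are passed in as static values (the real wrappers may compute them per token), the RMSNorm omits its learned scale, and every function name is hypothetical.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def sinkhorn(logits: Matrix, iters: int = 20) -> Matrix:
    """Approximately project exp(logits) onto the doubly stochastic matrices."""
    m = [[math.exp(v) for v in row] for row in logits]
    n = len(m)
    for _ in range(iters):
        row_sums = [sum(row) for row in m]
        m = [[v / s for v in row] for row, s in zip(m, row_sums)]  # rows -> 1
        cols = [sum(m[i][j] for i in range(n)) for j in range(n)]
        m = [[m[i][j] / cols[j] for j in range(n)] for i in range(n)]  # cols -> 1
    return m


def rms_norm(x: Vector, eps: float = 1e-6) -> Vector:
    # Unweighted RMSNorm; the real module also has a learned scale.
    inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv for v in x]


def mhc_block(streams: Matrix, sublayer: Callable[[Vector], Vector],
              h_pre: Vector, h_post: Vector, h_res: Matrix) -> Matrix:
    n, d = len(streams), len(streams[0])
    # Read: collapse the n residual streams into one sub-block input.
    x = [sum(h_pre[i] * streams[i][c] for i in range(n)) for c in range(d)]
    y = sublayer(x)
    # Mix the streams with h_res, then write the sub-block output back via h_post.
    return [[sum(h_res[i][j] * streams[j][c] for j in range(n)) + h_post[i] * y[c]
             for c in range(d)] for i in range(n)]


def transformer_layer(streams: Matrix,
                      attn: Callable[[Vector], Vector],
                      ffn: Callable[[Vector], Vector],
                      attn_mhc: Tuple[Vector, Vector, Matrix],
                      ffn_mhc: Tuple[Vector, Vector, Matrix],
                      eps: float = 1e-6) -> Matrix:
    # One mHC wrapper and one norm per sub-block; the layer itself holds no weights.
    streams = mhc_block(streams, lambda x: attn(rms_norm(x, eps)), *attn_mhc)
    streams = mhc_block(streams, lambda x: ffn(rms_norm(x, eps)), *ffn_mhc)
    return streams
```

Because H_res has rows and columns summing to 1, the stream mixing neither amplifies nor shrinks the residual signal on average, which is the point of the constraint.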

Tests: schedule construction + validation for both variants.
No forward tests yet (depends on FMHA kernel + KV cache).
2026-05-21 23:11:09 +00:00


"""DSV4 architecture configuration.
One config class for both Flash and Pro — they differ only in numeric
parameters, not structure. Construct via .flash() / .pro() classmethods
or load from a checkpoint's config.json.
"""
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class DSV4Config:
    # Identity
    variant: str  # "flash" | "pro"
    vocab_size: int

    # Transformer
    num_layers: int  # Flash: 43, Pro: 61
    hidden_size: int  # Flash: 4096, Pro: 7168

    # Attention (shared across CSA/HCA/SWA)
    num_query_heads: int  # Flash: 64, Pro: 128
    head_dim: int  # Both: 512
    rope_dim: int  # Both: 64 (partial RoPE on last 64 dims)
    query_compression_dim: int  # Flash: 1024, Pro: 1536

    # CSA-specific
    csa_compression_ratio: int  # m, both: 4
    csa_top_k: int  # Flash: 512, Pro: 1024
    indexer_num_heads: int  # Both: 64
    indexer_head_dim: int  # Both: 128

    # HCA-specific
    hca_compression_ratio: int  # m', both: 128

    # SWA branch
    sliding_window: int  # n_win, both: 128

    # Output projection grouping
    num_output_groups: int  # Flash: 8, Pro: 16
    output_group_dim: int  # Both: 1024

    # MoE
    num_routed_experts: int  # Flash: 256, Pro: 384
    num_shared_experts: int  # Both: 1
    num_experts_per_tok: int  # Both: 6
    moe_intermediate_size: int  # Flash: 2048, Pro: 3072
    num_hash_routing_layers: int  # Both: 3
    routed_scaling_factor: float  # Verify per checkpoint

    # mHC
    n_hc: int  # Both: 4
    sinkhorn_iters: int  # Both: 20

    # Architecture schedule: the first N layers use a non-standard attention
    # pattern (paper §4.2.1).
    #   Flash: layers 0-1 = SWA only
    #   Pro:   layers 0-1 = HCA only
    num_special_first_layers: int = 2

    # Norm
    rms_norm_eps: float = 1e-6

    @classmethod
    def flash(cls) -> "DSV4Config":
        return cls(
            variant="flash",
            vocab_size=128_000,
            num_layers=43,
            hidden_size=4096,
            num_query_heads=64,
            head_dim=512,
            rope_dim=64,
            query_compression_dim=1024,
            csa_compression_ratio=4,
            csa_top_k=512,
            indexer_num_heads=64,
            indexer_head_dim=128,
            hca_compression_ratio=128,
            sliding_window=128,
            num_output_groups=8,
            output_group_dim=1024,
            num_routed_experts=256,
            num_shared_experts=1,
            num_experts_per_tok=6,
            moe_intermediate_size=2048,
            num_hash_routing_layers=3,
            routed_scaling_factor=2.5,
            n_hc=4,
            sinkhorn_iters=20,
        )

    @classmethod
    def pro(cls) -> "DSV4Config":
        return cls(
            variant="pro",
            vocab_size=128_000,
            num_layers=61,
            hidden_size=7168,
            num_query_heads=128,
            head_dim=512,
            rope_dim=64,
            query_compression_dim=1536,
            csa_compression_ratio=4,
            csa_top_k=1024,
            indexer_num_heads=64,
            indexer_head_dim=128,
            hca_compression_ratio=128,
            sliding_window=128,
            num_output_groups=16,
            output_group_dim=1024,
            num_routed_experts=384,
            num_shared_experts=1,
            num_experts_per_tok=6,
            moe_intermediate_size=3072,
            num_hash_routing_layers=3,
            routed_scaling_factor=2.5,
            n_hc=4,
            sinkhorn_iters=20,
        )
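Two hypothetical helpers built on the fields above. `validate_config` runs cross-field checks that the comments imply; these are assumptions, not checks the class performs today. `kv_entries` turns the compression ratios into entry counts for a given sequence length, assuming floor division drops a partial trailing block. A `SimpleNamespace` stands in for `DSV4Config.flash()` so the block runs on its own.

```python
from types import SimpleNamespace
from typing import Dict


def validate_config(cfg) -> None:
    """Cross-field checks implied by the config comments (illustrative only)."""
    checks = [
        (cfg.variant in ("flash", "pro"), "variant must be 'flash' or 'pro'"),
        (cfg.rope_dim <= cfg.head_dim, "rope_dim must fit inside head_dim"),
        (cfg.num_query_heads % cfg.num_output_groups == 0,
         "query heads must split evenly into output groups"),
        (cfg.num_experts_per_tok <= cfg.num_routed_experts,
         "cannot route to more experts than exist"),
        (cfg.num_hash_routing_layers <= cfg.num_layers,
         "more hash-routed layers than layers"),
        (cfg.num_special_first_layers < cfg.num_layers,
         "special prefix covers the whole model"),
    ]
    for ok, msg in checks:
        if not ok:
            raise ValueError(msg)


def kv_entries(cfg, seq_len: int) -> Dict[str, int]:
    """Compressed KV entries per branch and per-query selections (floor assumed)."""
    csa_blocks = seq_len // cfg.csa_compression_ratio
    return {
        "csa_compressed": csa_blocks,
        "csa_selected": min(cfg.csa_top_k, csa_blocks),
        "hca_compressed": seq_len // cfg.hca_compression_ratio,
        "swa_window": min(cfg.sliding_window, seq_len),
    }


# Stand-in for DSV4Config.flash(): only the fields the helpers read.
flash_like = SimpleNamespace(
    variant="flash", num_layers=43, head_dim=512, rope_dim=64,
    num_query_heads=64, num_output_groups=8, num_routed_experts=256,
    num_experts_per_tok=6, num_hash_routing_layers=3,
    num_special_first_layers=2, csa_compression_ratio=4, csa_top_k=512,
    hca_compression_ratio=128, sliding_window=128,
)
```

With the real class, `validate_config(DSV4Config.flash())` should work unchanged. Because the dataclass is frozen, a smaller test config is derived with `dataclasses.replace(DSV4Config.flash(), num_layers=4)` instead of mutating fields; moving the checks into `__post_init__` would run them on every construction, including configs produced by `replace`.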