Files
nvfp4-megamoe-kernel/dsv4/model/layer_schedule.py
biondizzle 66a89859ed Layer dispatch: config, schedule, attention/FFN sub-blocks, TransformerLayer
DSV4Config: frozen dataclass with .flash() / .pro() classmethods.
All architectural constants (dims, heads, MoE params, mHC) in one place.

LayerSchedule: pure-data per-layer-index -> (attn_type, ffn_type, router_mode).
  Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
  Pro:   HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
  Both:  first 3 MoE layers = hash routing, rest = dense
  validate_schedule() enforces correctness at construction.

AttentionSubBlock: CSA / HCA / SWA variants.
  - Low-rank Q projection (q_down -> q_up)
  - KV down-projection (varies by attn type: 4h/2h/1h)
  - CSA: indexer_q_up + indexer_head_weights
  - Grouped output projection (wo_a + wo_b)
  - Kernel calls are imports (NotImplementedError until kernel lands)
  - No PyTorch fallback paths

FFNSubBlock: MoE + shared expert.
  - Router (hash/dense) mode from LayerSpec
  - Nvfp4MoE + Nvfp4SharedExpert

TransformerLayer: composition of mHC + norm + attention + FFN.
  - Two mHC wrappers (attn + ffn sub-blocks)
  - Two RMSNorm (one per sub-block)
  - Pure orchestration, no learned params on the layer itself

Tests: schedule construction + validation for both variants.
No forward tests yet (depends on FMHA kernel + KV cache).
2026-05-21 23:11:09 +00:00

103 lines
3.7 KiB
Python

"""Per-layer-index architecture schedule for DSV4.
Encodes the attention / FFN / routing pattern from the paper. Kept
separate from layer.py so the KV cache allocator can query it without
instantiating any layers.
"""
from __future__ import annotations
from enum import Enum
from dataclasses import dataclass
from typing import List
from dsv4.model.config import DSV4Config
class AttentionType(Enum):
    """Attention variants a layer can be assigned in the schedule."""

    # Sliding window only (Flash layers 0-1).
    SWA = "swa"
    # Heavily Compressed Attention (Pro layers 0-1, alternating later).
    HCA = "hca"
    # Compressed Sparse Attention (alternating from layer 2).
    CSA = "csa"
class FFNType(Enum):
    """Feed-forward variants a layer can be assigned in the schedule."""

    # DeepSeekMoE — both variants use MoE in every layer.
    MOE = "moe"
class RouterMode(Enum):
    """Routing strategy used by a layer's MoE router."""

    # Layers 0-2: deterministic by token ID.
    HASH = "hash"
    # Layers 3+: sqrt(softplus) + top-k.
    DENSE = "dense"
@dataclass(frozen=True)
class LayerSpec:
    """Immutable record of one transformer layer's architectural type.

    Bundles the layer's index with its attention variant, FFN variant and
    router mode.
    """

    # Index of the layer this spec describes.
    layer_idx: int
    # Attention variant assigned to the layer.
    attn: AttentionType
    # Feed-forward variant assigned to the layer.
    ffn: FFNType
    # Routing mode assigned to the layer.
    router_mode: RouterMode
def build_schedule(config: DSV4Config) -> List[LayerSpec]:
    """Return one ``LayerSpec`` per layer index for ``config``.

    Pattern (paper §4.2.1):
      * Flash: the leading special layers are pure SWA, then CSA / HCA
        alternate.
      * Pro: the leading special layers are HCA, then CSA / HCA alternate.
      * Both: layers below ``config.num_hash_routing_layers`` use hash
        routing; every later layer uses dense routing.

    The alternation begins with CSA on the first non-special layer in both
    variants, and every layer's FFN is MoE.

    Args:
        config: Source of ``num_layers``, ``num_special_first_layers``,
            ``num_hash_routing_layers`` and ``variant``.

    Returns:
        A list of ``config.num_layers`` specs, ordered by layer index.
    """
    special_count = config.num_special_first_layers

    def attention_for(idx: int) -> AttentionType:
        # Past the special prefix: even offsets are CSA, odd offsets are HCA.
        if idx >= special_count:
            is_odd_offset = (idx - special_count) % 2
            return AttentionType.HCA if is_odd_offset else AttentionType.CSA
        # Inside the special prefix: any variant other than "flash" gets HCA.
        if config.variant == "flash":
            return AttentionType.SWA
        return AttentionType.HCA

    def router_for(idx: int) -> RouterMode:
        if idx < config.num_hash_routing_layers:
            return RouterMode.HASH
        return RouterMode.DENSE

    return [
        LayerSpec(
            layer_idx=idx,
            attn=attention_for(idx),
            ffn=FFNType.MOE,  # DSV4 is MoE in every layer.
            router_mode=router_for(idx),
        )
        for idx in range(config.num_layers)
    ]
def validate_schedule(schedule: List[LayerSpec], config: DSV4Config) -> None:
    """Check that ``schedule`` matches the pattern ``config`` prescribes.

    A wrong schedule produces silent garbage downstream, so every mismatch
    is reported loudly. The checks are raised explicitly rather than written
    as ``assert`` statements, so they still run when Python is started with
    ``-O`` (which strips ``assert``).

    Checks, in order:
      1. The schedule has exactly ``config.num_layers`` entries.
      2. Layers below ``config.num_hash_routing_layers`` are HASH-routed and
         all later layers are DENSE-routed.
      3. The first ``config.num_special_first_layers`` layers use SWA when
         ``config.variant == "flash"`` and HCA otherwise; the remaining
         layers alternate CSA / HCA, starting with CSA.
      4. Every layer's FFN is MoE.

    Args:
        schedule: Per-layer specs, ordered by layer index.
        config: Source of ``num_layers``, ``num_hash_routing_layers``,
            ``num_special_first_layers`` and ``variant``.

    Raises:
        AssertionError: On the first violated check, with a message naming
            the offending layer.
    """
    num_layers = config.num_layers
    if len(schedule) != num_layers:
        raise AssertionError(
            f"schedule has {len(schedule)} layers, config says {num_layers}"
        )

    # Router mode. Iterating the schedule itself (not a range built from the
    # config counts) avoids an IndexError when a count exceeds the number of
    # layers.
    num_hash = config.num_hash_routing_layers
    for i, spec in enumerate(schedule):
        if i < num_hash:
            if spec.router_mode != RouterMode.HASH:
                raise AssertionError(f"layer {i} should be HASH-routed")
        elif spec.router_mode != RouterMode.DENSE:
            raise AssertionError(f"layer {i} should be DENSE-routed")

    # Attention type: special prefix first, then CSA / HCA alternation.
    num_special = config.num_special_first_layers
    expected_first = (
        AttentionType.SWA if config.variant == "flash" else AttentionType.HCA
    )
    for i, spec in enumerate(schedule):
        if i < num_special:
            if spec.attn != expected_first:
                raise AssertionError(
                    f"layer {i} should be {expected_first} for {config.variant}"
                )
            continue
        # Even offsets past the special prefix are CSA, odd offsets are HCA.
        expected = (
            AttentionType.CSA if (i - num_special) % 2 == 0 else AttentionType.HCA
        )
        if spec.attn != expected:
            raise AssertionError(
                f"layer {i} should be {expected}, got {spec.attn}"
            )

    # Every layer is MoE.
    for i, spec in enumerate(schedule):
        if spec.ffn != FFNType.MOE:
            raise AssertionError(f"layer {i} should be MOE")