Layer dispatch: config, schedule, attention/FFN sub-blocks, TransformerLayer
DSV4Config: frozen dataclass with .flash() / .pro() classmethods.
All architectural constants (dims, heads, MoE params, mHC) in one place.
LayerSchedule: pure-data per-layer-index -> (attn_type, ffn_type, router_mode).
Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
Pro: HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
Both: first 3 MoE layers = hash routing, rest = dense
validate_schedule() enforces correctness at construction.
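
The schedule patterns above can be illustrated with a minimal sketch. It is not the repository's LayerSchedule: the names LayerSpecSketch and build_schedule_sketch are hypothetical, the enums are re-declared locally, ffn_type is omitted, and it assumes that the "..." continues the alternating CSA/HCA pattern to the last layer and that "first 3 MoE layers" means layer indices 0 to 2.

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum


class AttnType(Enum):
    SWA = "swa"
    CSA = "csa"
    HCA = "hca"


class RouterMode(Enum):
    HASH = "hash"
    DENSE = "dense"


@dataclass(frozen=True)
class LayerSpecSketch:
    """Hypothetical stand-in for the real LayerSpec (ffn_type omitted)."""
    layer_idx: int
    attn_type: AttnType
    router_mode: RouterMode


def build_schedule_sketch(variant: str) -> list[LayerSpecSketch]:
    """Build a per-layer schedule for 'flash' or 'pro' (assumed patterns)."""
    if variant == "flash":
        prefix, num_layers = [AttnType.SWA, AttnType.SWA], 43
    elif variant == "pro":
        prefix, num_layers = [AttnType.HCA, AttnType.HCA], 61
    else:
        raise ValueError(f"unknown variant: {variant!r}")

    specs = []
    for i in range(num_layers):
        if i < len(prefix):
            attn = prefix[i]
        else:
            # Assumption: after the two-layer prefix, CSA and HCA alternate.
            attn = AttnType.CSA if (i - len(prefix)) % 2 == 0 else AttnType.HCA
        # Assumption: the hash-routed layers are indices 0, 1 and 2.
        router = RouterMode.HASH if i < 3 else RouterMode.DENSE
        specs.append(LayerSpecSketch(i, attn, router))
    return specs
```

Keeping the schedule as plain frozen records means it can be built and checked without touching any kernel or tensor code, which matches the "pure-data" description above.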
AttentionSubBlock: CSA / HCA / SWA variants.
- Low-rank Q projection (q_down -> q_up)
- KV down-projection (varies by attn type: 4h/2h/1h)
- CSA: indexer_q_up + indexer_head_weights
- Grouped output projection (wo_a + wo_b)
- Kernel calls are imports (NotImplementedError until kernel lands)
- No PyTorch fallback paths
FFNSubBlock: MoE + shared expert.
- Router (hash/dense) mode from LayerSpec
- Nvfp4MoE + Nvfp4SharedExpert
TransformerLayer: composition of mHC + norm + attention + FFN.
- Two mHC wrappers (attn + ffn sub-blocks)
- Two RMSNorm (one per sub-block)
- Pure orchestration, no learned params on the layer itself
Tests: schedule construction + validation for both variants.
No forward tests yet (depends on FMHA kernel + KV cache).
2026-05-21 23:11:09 +00:00
"""DSV4 FFN sub-block: routed MoE + shared expert.

The router instance encapsulates hash-vs-dense; this sub-block doesn't
have to care, it just calls router(x, token_ids) and feeds the result
to Nvfp4MoE. Shared expert runs in parallel (logically; kernels
can overlap).
"""

from __future__ import annotations
from typing import TYPE_CHECKING
import torch

from dsv4.layers.router import Router
from dsv4.layers.moe import Nvfp4MoE
from dsv4.layers.shared_expert import Nvfp4SharedExpert
from dsv4.model.layer_schedule import LayerSpec, RouterMode

if TYPE_CHECKING:
    from dsv4.model.config import DSV4Config


class FFNSubBlock:
    def __init__(self, config: "DSV4Config", spec: LayerSpec):
        self.config = config
        self.spec = spec

        self.router = Router(
            hidden_size=config.hidden_size,
            num_experts=config.num_routed_experts,
            top_k=config.num_experts_per_tok,
            routed_scaling_factor=config.routed_scaling_factor,
            mode="hash" if spec.router_mode == RouterMode.HASH else "dense",
            vocab_size=config.vocab_size if spec.router_mode == RouterMode.HASH else None,
        )
        self.moe = Nvfp4MoE(
            num_experts=config.num_routed_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            top_k=config.num_experts_per_tok,
        )
        self.shared = Nvfp4SharedExpert(
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
        )

    def forward(
        self,
        x: torch.Tensor,          # (T, hidden_size) BF16, post-RMSNorm
        token_ids: torch.Tensor,  # (T,) int32; needed only for hash routing
    ) -> torch.Tensor:
        topk_w, topk_ids = self.router(x, token_ids=token_ids)
        routed_out = self.moe.run(x, topk_w, topk_ids)
        shared_out = self.shared.run(x)
        return routed_out + shared_out
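
The data flow of forward() (route, run the routed experts, run the shared expert, sum) can be traced with a toy, pure-Python sketch. Everything below is a hypothetical stand-in: toy_hash_router picks experts as (token_id + j) % num_experts with uniform weights, which is not the repository's hash function, and the "experts" are simple scalings rather than NVFP4 kernels. It only mirrors the call order and the final addition.

```python
from __future__ import annotations


def toy_hash_router(x, token_ids, num_experts=4, top_k=2):
    """Stand-in router: picks experts from the token id alone (ignores x)."""
    topk_ids = [[(t + j) % num_experts for j in range(top_k)] for t in token_ids]
    topk_w = [[1.0 / top_k] * top_k for _ in token_ids]
    return topk_w, topk_ids


def toy_moe(x, topk_w, topk_ids):
    """Stand-in routed experts: expert e multiplies its input by (e + 1)."""
    out = []
    for row, weights, ids in zip(x, topk_w, topk_ids):
        # Weighted sum of the selected experts' outputs for this token.
        scale = sum(w * (e + 1) for w, e in zip(weights, ids))
        out.append([scale * v for v in row])
    return out


def toy_shared_expert(x):
    """Stand-in shared expert: identity, applied to every token."""
    return [list(row) for row in x]


def ffn_forward_sketch(x, token_ids):
    """Same call order as FFNSubBlock.forward, using the toy stand-ins."""
    topk_w, topk_ids = toy_hash_router(x, token_ids)
    routed_out = toy_moe(x, topk_w, topk_ids)
    shared_out = toy_shared_expert(x)
    # Element-wise routed_out + shared_out, written out for nested lists.
    return [[r + s for r, s in zip(rr, sr)] for rr, sr in zip(routed_out, shared_out)]


x = [[1.0, 2.0], [3.0, 4.0]]   # two tokens, hidden size 2
token_ids = [0, 3]
out = ffn_forward_sketch(x, token_ids)
print(out)  # [[2.5, 5.0], [10.5, 14.0]]
```

Because the toy router depends only on token_ids, two tokens with the same id always land on the same experts regardless of x, which is the property that makes token_ids a required input in hash mode.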