Files
nvfp4-megamoe-kernel/dsv4/layers/ffn.py

54 lines
1.9 KiB
Python
Raw Normal View History

"""DSV4 FFN sub-block — routed MoE + shared expert.

The router instance encapsulates hash-vs-dense routing; this sub-block does
not have to care. It just calls router(x, token_ids) and feeds the result
to Nvfp4MoE. The shared expert is logically parallel to the routed path:
the two are independent of each other, so their kernels can overlap.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from dsv4.layers.router import Router
from dsv4.layers.moe import Nvfp4MoE
from dsv4.layers.shared_expert import Nvfp4SharedExpert
from dsv4.model.layer_schedule import LayerSpec, RouterMode
if TYPE_CHECKING:
from dsv4.model.config import DSV4Config
class FFNSubBlock:
    """Feed-forward sub-block: routed MoE output plus shared-expert output.

    Holds three components built from the model config: a ``Router`` that
    returns per-token top-k weights and expert ids, an ``Nvfp4MoE`` that
    consumes them, and an ``Nvfp4SharedExpert`` that is run on the full
    input. ``forward`` returns the sum of the routed and shared outputs.
    """

    def __init__(self, config: "DSV4Config", spec: LayerSpec):
        """Build the router, routed-MoE and shared-expert components.

        Args:
            config: Model configuration; supplies hidden and intermediate
                sizes, expert count, top-k, routed scaling factor and (for
                hash routing) the vocabulary size.
            spec: Per-layer spec; its ``router_mode`` selects hash vs dense
                routing for this layer.
        """
        self.config = config
        self.spec = spec

        # Resolve the routing mode once. forward() reuses the cached flag to
        # decide whether token ids are mandatory.
        self._hash_routing = spec.router_mode == RouterMode.HASH

        self.router = Router(
            hidden_size=config.hidden_size,
            num_experts=config.num_routed_experts,
            top_k=config.num_experts_per_tok,
            routed_scaling_factor=config.routed_scaling_factor,
            mode="hash" if self._hash_routing else "dense",
            # The vocabulary size is only read and passed for hash routing.
            vocab_size=config.vocab_size if self._hash_routing else None,
        )
        self.moe = Nvfp4MoE(
            num_experts=config.num_routed_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            top_k=config.num_experts_per_tok,
        )
        # The shared expert is built with the same intermediate size as the
        # routed experts.
        self.shared = Nvfp4SharedExpert(
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
        )

    def forward(
        self,
        x: torch.Tensor,  # (T, hidden_size) BF16, post-RMSNorm
        token_ids: torch.Tensor | None = None,  # (T,) int32 — needed only for hash routing
    ) -> torch.Tensor:
        """Run the routed experts and the shared expert, and sum the results.

        Args:
            x: Input activations.
            token_ids: Token ids forwarded to the router. May be omitted on
                layers that do not use hash routing.
                NOTE(review): when omitted, ``None`` is forwarded to the
                router — confirm the router's dense path ignores it.

        Returns:
            ``moe.run(x, topk_w, topk_ids) + shared.run(x)``.

        Raises:
            ValueError: If this layer uses hash routing and ``token_ids`` is
                ``None``.
        """
        # Fail fast with a clear message instead of an opaque error from
        # inside the router.
        if token_ids is None and self._hash_routing:
            raise ValueError(
                "token_ids is required when the layer uses hash routing"
            )

        topk_w, topk_ids = self.router(x, token_ids=token_ids)
        routed_out = self.moe.run(x, topk_w, topk_ids)
        # Depends only on x, not on the routing result.
        shared_out = self.shared.run(x)
        return routed_out + shared_out