[MoE] Enable Shared/Routed Overlap For Latent MoE (Nemotron-H) (#32790)
Signed-off-by: dafrimi <dafrimi@nvidia.com>
This commit is contained in:
162
tests/kernels/moe/test_shared_fused_moe_routed_transform.py
Normal file
162
tests/kernels/moe/test_shared_fused_moe_routed_transform.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Tests for SharedFusedMoE with routed_input_transform.
|
||||||
|
|
||||||
|
Verifies that applying routed_input_transform inside SharedFusedMoE
|
||||||
|
produces the same results as applying the transform manually outside.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
|
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleLinear(nn.Module):
|
||||||
|
"""A simple linear transform mimicking latent projection in latent MoE."""
|
||||||
|
|
||||||
|
def __init__(self, in_features: int, out_features: int, dtype: torch.dtype):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
torch.randn(out_features, in_features, device="cuda", dtype=dtype) / 10
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return nn.functional.linear(x, self.weight)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleSharedExperts(nn.Module):
|
||||||
|
"""A simple 2-layer MLP mimicking shared experts."""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size: int, intermediate_size: int, dtype: torch.dtype):
|
||||||
|
super().__init__()
|
||||||
|
self.up = nn.Linear(
|
||||||
|
hidden_size, intermediate_size * 2, bias=False, device="cuda", dtype=dtype
|
||||||
|
)
|
||||||
|
self.down = nn.Linear(
|
||||||
|
intermediate_size, hidden_size, bias=False, device="cuda", dtype=dtype
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
self.up.weight.div_(10)
|
||||||
|
self.down.weight.div_(10)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
gate_up = self.up(x)
|
||||||
|
gate, up = gate_up.chunk(2, dim=-1)
|
||||||
|
return self.down(nn.functional.silu(gate) * up)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_cuda():
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
pytest.skip("CUDA not available")
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_tokens", [1, 32])
|
||||||
|
@pytest.mark.parametrize("hidden_size,latent_size", [(256, 128), (128, 64)])
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
|
def test_routed_input_transform_inside_vs_outside(
|
||||||
|
num_tokens: int,
|
||||||
|
hidden_size: int,
|
||||||
|
latent_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
dist_init,
|
||||||
|
workspace_init,
|
||||||
|
):
|
||||||
|
"""Compare SharedFusedMoE with transform inside vs manually applying outside.
|
||||||
|
Method A (inside): SharedFusedMoE with routed_input_transform
|
||||||
|
Method B (outside): Manually transform, then SharedFusedMoE without transform
|
||||||
|
"""
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
num_experts = 8
|
||||||
|
top_k = 2
|
||||||
|
intermediate_size = hidden_size * 2
|
||||||
|
|
||||||
|
vllm_config = VllmConfig()
|
||||||
|
vllm_config.compilation_config.static_forward_context = dict()
|
||||||
|
|
||||||
|
shared_experts = SimpleSharedExperts(hidden_size, intermediate_size, dtype)
|
||||||
|
routed_transform = SimpleLinear(hidden_size, latent_size, dtype)
|
||||||
|
|
||||||
|
with set_current_vllm_config(vllm_config):
|
||||||
|
# Method A: SharedFusedMoE WITH routed_input_transform
|
||||||
|
moe_with_transform = SharedFusedMoE(
|
||||||
|
shared_experts=shared_experts,
|
||||||
|
routed_input_transform=routed_transform,
|
||||||
|
num_experts=num_experts,
|
||||||
|
top_k=top_k,
|
||||||
|
hidden_size=latent_size,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
reduce_results=False,
|
||||||
|
renormalize=True,
|
||||||
|
params_dtype=dtype,
|
||||||
|
tp_size=1,
|
||||||
|
dp_size=1,
|
||||||
|
pcp_size=1,
|
||||||
|
prefix="moe_with_transform",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Method B: SharedFusedMoE WITHOUT routed_input_transform
|
||||||
|
# Note: shared_experts=None because when transform is done outside,
|
||||||
|
moe_without_transform = SharedFusedMoE(
|
||||||
|
shared_experts=None,
|
||||||
|
routed_input_transform=None,
|
||||||
|
num_experts=num_experts,
|
||||||
|
top_k=top_k,
|
||||||
|
hidden_size=latent_size,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
reduce_results=False,
|
||||||
|
renormalize=True,
|
||||||
|
params_dtype=dtype,
|
||||||
|
tp_size=1,
|
||||||
|
dp_size=1,
|
||||||
|
pcp_size=1,
|
||||||
|
prefix="moe_without_transform",
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
moe_without_transform.w13_weight.copy_(moe_with_transform.w13_weight)
|
||||||
|
moe_without_transform.w2_weight.copy_(moe_with_transform.w2_weight)
|
||||||
|
|
||||||
|
moe_with_transform.quant_method.process_weights_after_loading(
|
||||||
|
moe_with_transform
|
||||||
|
)
|
||||||
|
moe_without_transform.quant_method.process_weights_after_loading(
|
||||||
|
moe_without_transform
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype)
|
||||||
|
router_logits = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
|
||||||
|
|
||||||
|
with set_forward_context(None, vllm_config, num_tokens=num_tokens):
|
||||||
|
shared_out_A, routed_out_A = moe_with_transform(
|
||||||
|
hidden_states, router_logits
|
||||||
|
)
|
||||||
|
|
||||||
|
transformed_hidden = routed_transform(hidden_states)
|
||||||
|
shared_out_B, routed_out_B = moe_without_transform(
|
||||||
|
transformed_hidden, router_logits
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(
|
||||||
|
routed_out_A,
|
||||||
|
routed_out_B,
|
||||||
|
atol=1e-3,
|
||||||
|
rtol=1e-3,
|
||||||
|
msg="Routed output should match: transform inside vs outside",
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_shared_out = shared_experts(hidden_states)
|
||||||
|
|
||||||
|
torch.testing.assert_close(
|
||||||
|
shared_out_A,
|
||||||
|
expected_shared_out,
|
||||||
|
atol=1e-3,
|
||||||
|
rtol=1e-3,
|
||||||
|
)
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Generator, Iterable
|
||||||
from contextlib import nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal, cast, get_args, overload
|
from typing import Literal, cast, get_args, overload
|
||||||
|
|
||||||
@@ -351,6 +351,10 @@ class FusedMoE(CustomOp):
|
|||||||
"Enabled separate cuda stream for MoE shared_experts", scope="local"
|
"Enabled separate cuda stream for MoE shared_experts", scope="local"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# For latent MoE: stores original hidden_states before routed_input_transform
|
||||||
|
# so shared_experts can use it for cloning (they need original dimension)
|
||||||
|
self._shared_experts_input: torch.Tensor | None = None
|
||||||
|
|
||||||
if params_dtype is None:
|
if params_dtype is None:
|
||||||
params_dtype = torch.get_default_dtype()
|
params_dtype = torch.get_default_dtype()
|
||||||
self.params_dtype = params_dtype
|
self.params_dtype = params_dtype
|
||||||
@@ -664,6 +668,39 @@ class FusedMoE(CustomOp):
|
|||||||
def gate(self) -> torch.nn.Module | None:
|
def gate(self) -> torch.nn.Module | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Hook to transform hidden_states before passing to routed experts.
|
||||||
|
For latent MoE: transforms [S, hidden_size] → [S, moe_latent_size].
|
||||||
|
The original hidden_states is saved in _shared_experts_input so
|
||||||
|
shared_experts still receive the original [S, hidden_size].
|
||||||
|
|
||||||
|
Override in subclasses (e.g., SharedFusedMoE) for latent MoE.
|
||||||
|
"""
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _set_shared_experts_input(
|
||||||
|
self, value: torch.Tensor | None
|
||||||
|
) -> Generator[None, None, None]:
|
||||||
|
"""Context manager to safely set/clear _shared_experts_input."""
|
||||||
|
self._shared_experts_input = value
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
self._shared_experts_input = None
|
||||||
|
|
||||||
|
def _get_shared_experts_input(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Get input for shared experts.
|
||||||
|
|
||||||
|
For latent MoE: shared_experts need original [S, hidden_size],
|
||||||
|
not the transformed [S, latent_size] used by routed experts.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
self._shared_experts_input
|
||||||
|
if self._shared_experts_input is not None
|
||||||
|
else hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tp_size(self):
|
def tp_size(self):
|
||||||
return self.moe_parallel_config.tp_size
|
return self.moe_parallel_config.tp_size
|
||||||
@@ -855,9 +892,11 @@ class FusedMoE(CustomOp):
|
|||||||
if use_shared_experts_stream:
|
if use_shared_experts_stream:
|
||||||
assert self.shared_experts_stream is not None
|
assert self.shared_experts_stream is not None
|
||||||
|
|
||||||
|
shared_experts_input = self._get_shared_experts_input(hidden_states)
|
||||||
|
|
||||||
# Clone BEFORE switching streams to avoid race condition
|
# Clone BEFORE switching streams to avoid race condition
|
||||||
# where routed_expert kernel may mutate hidden_states.
|
# where routed_expert kernel may mutate hidden_states.
|
||||||
hidden_states_clone = hidden_states.clone()
|
hidden_states_clone = shared_experts_input.clone()
|
||||||
|
|
||||||
# Record that the clone will be used by shared_experts_stream
|
# Record that the clone will be used by shared_experts_stream
|
||||||
# to avoid gc issue from deallocation of hidden_states_clone
|
# to avoid gc issue from deallocation of hidden_states_clone
|
||||||
@@ -1537,11 +1576,20 @@ class FusedMoE(CustomOp):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
og_hidden_states = hidden_states.shape[-1]
|
# For latent MoE: save ORIGINAL hidden_states before transform
|
||||||
if self.hidden_size != og_hidden_states:
|
# (shared_experts need original dimension, routed experts use transformed)
|
||||||
|
original_hidden_states = hidden_states
|
||||||
|
original_hidden_dim = hidden_states.shape[-1]
|
||||||
|
|
||||||
|
# Apply transform for routed experts (e.g., latent projection for latent MoE)
|
||||||
|
hidden_states = self.apply_routed_input_transform(hidden_states)
|
||||||
|
|
||||||
|
# This is the dimension after transform (for routed expert output slicing)
|
||||||
|
transformed_hidden_dim = hidden_states.shape[-1]
|
||||||
|
if self.hidden_size != transformed_hidden_dim:
|
||||||
hidden_states = F.pad(
|
hidden_states = F.pad(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
(0, self.hidden_size - og_hidden_states),
|
(0, self.hidden_size - transformed_hidden_dim),
|
||||||
mode="constant",
|
mode="constant",
|
||||||
value=0.0,
|
value=0.0,
|
||||||
)
|
)
|
||||||
@@ -1576,22 +1624,31 @@ class FusedMoE(CustomOp):
|
|||||||
fused_output = torch.ops.vllm.moe_forward(
|
fused_output = torch.ops.vllm.moe_forward(
|
||||||
hidden_states, router_logits, encode_layer_name()
|
hidden_states, router_logits, encode_layer_name()
|
||||||
)
|
)
|
||||||
return reduce_output(fused_output)[..., :og_hidden_states]
|
return reduce_output(fused_output)[..., :transformed_hidden_dim]
|
||||||
else:
|
else:
|
||||||
if current_platform.is_tpu() or current_platform.is_cpu():
|
if current_platform.is_tpu() or current_platform.is_cpu():
|
||||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||||
# will switch to using the moe_forward custom op.
|
# will switch to using the moe_forward custom op.
|
||||||
# Note: CPU doesn't require wrapped forward_impl.
|
# Note: CPU doesn't require wrapped forward_impl.
|
||||||
shared_output, fused_output = self.forward_impl(
|
with self._set_shared_experts_input(original_hidden_states):
|
||||||
hidden_states, router_logits
|
shared_output, fused_output = self.forward_impl(
|
||||||
)
|
hidden_states, router_logits
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
|
# Custom op handles setting/clearing _shared_experts_input internally
|
||||||
|
# We pass original tensor for shared experts (not transformed)
|
||||||
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
|
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
|
||||||
hidden_states, router_logits, encode_layer_name()
|
hidden_states,
|
||||||
|
router_logits,
|
||||||
|
encode_layer_name(),
|
||||||
|
original_hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# shared_output uses original dimension (before transform)
|
||||||
|
# fused_output uses transformed dimension (after transform)
|
||||||
return (
|
return (
|
||||||
reduce_output(shared_output)[..., :og_hidden_states],
|
reduce_output(shared_output)[..., :original_hidden_dim],
|
||||||
reduce_output(fused_output)[..., :og_hidden_states],
|
reduce_output(fused_output)[..., :transformed_hidden_dim],
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1831,7 +1888,8 @@ class FusedMoE(CustomOp):
|
|||||||
# because matrix multiply maybe modify the hidden_states.
|
# because matrix multiply maybe modify the hidden_states.
|
||||||
if has_separate_shared_experts and not use_shared_experts_stream:
|
if has_separate_shared_experts and not use_shared_experts_stream:
|
||||||
assert self.shared_experts is not None
|
assert self.shared_experts is not None
|
||||||
shared_output = self.shared_experts(hidden_states)
|
shared_input = self._get_shared_experts_input(hidden_states)
|
||||||
|
shared_output = self.shared_experts(shared_input)
|
||||||
|
|
||||||
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
|
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
|
||||||
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
|
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
|
||||||
@@ -2021,19 +2079,34 @@ def moe_forward_shared(
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
|
shared_experts_input: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
self = get_layer_from_name(layer_name)
|
self = get_layer_from_name(layer_name)
|
||||||
assert self.shared_experts is not None
|
assert self.shared_experts is not None
|
||||||
return self.forward_impl(hidden_states, router_logits)
|
|
||||||
|
# Set here because torch.compile skips forward_native() setup code
|
||||||
|
# and calls this op directly. forward_impl() reads from this var.
|
||||||
|
with self._set_shared_experts_input(shared_experts_input):
|
||||||
|
return self.forward_impl(hidden_states, router_logits)
|
||||||
|
|
||||||
|
|
||||||
def moe_forward_shared_fake(
|
def moe_forward_shared_fake(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
|
shared_experts_input: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
shared_out = torch.empty_like(hidden_states)
|
# Output shapes:
|
||||||
|
# - fused_out: same as hidden_states (routed experts use transformed size)
|
||||||
|
# - shared_out: same as shared_experts_input if provided, else same as hidden_states
|
||||||
|
# (For latent MoE: shared experts use original hidden_size, not latent size)
|
||||||
fused_out = torch.empty_like(hidden_states)
|
fused_out = torch.empty_like(hidden_states)
|
||||||
|
|
||||||
|
if shared_experts_input is not None:
|
||||||
|
shared_out = torch.empty_like(shared_experts_input)
|
||||||
|
else:
|
||||||
|
shared_out = torch.empty_like(hidden_states)
|
||||||
|
|
||||||
return shared_out, fused_out
|
return shared_out, fused_out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,10 +23,12 @@ class SharedFusedMoE(FusedMoE):
|
|||||||
shared_experts: torch.nn.Module | None,
|
shared_experts: torch.nn.Module | None,
|
||||||
gate: torch.nn.Module | None = None,
|
gate: torch.nn.Module | None = None,
|
||||||
use_overlapped: bool = True,
|
use_overlapped: bool = True,
|
||||||
|
routed_input_transform: torch.nn.Module | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._shared_experts = shared_experts
|
self._shared_experts = shared_experts
|
||||||
|
self._routed_input_transform = routed_input_transform
|
||||||
|
|
||||||
# Disable shared expert overlap if:
|
# Disable shared expert overlap if:
|
||||||
# - we are using eplb with non-default backend, because of correctness issues
|
# - we are using eplb with non-default backend, because of correctness issues
|
||||||
@@ -56,6 +58,26 @@ class SharedFusedMoE(FusedMoE):
|
|||||||
def is_internal_router(self) -> bool:
|
def is_internal_router(self) -> bool:
|
||||||
return self.gate is not None
|
return self.gate is not None
|
||||||
|
|
||||||
|
def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Apply transform for routed experts (e.g., latent projection).
|
||||||
|
|
||||||
|
This is called by FusedMoE.forward_native. The original hidden_states
|
||||||
|
is saved separately so shared experts get [S, hidden_size] while
|
||||||
|
routed experts get the transformed [S, moe_latent_size].
|
||||||
|
|
||||||
|
TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
|
||||||
|
moved inside SharedFusedMoE to all-reduce on the smaller latent
|
||||||
|
dimension.
|
||||||
|
"""
|
||||||
|
if self._routed_input_transform is not None:
|
||||||
|
result = self._routed_input_transform(hidden_states)
|
||||||
|
# ReplicatedLinear returns (output, extra_bias) tuple.
|
||||||
|
# We only need the output tensor; extra_bias is not used here.
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
return result[0]
|
||||||
|
return result
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|||||||
@@ -188,10 +188,29 @@ class NemotronHMoE(nn.Module):
|
|||||||
prefix=f"{prefix}.shared_experts",
|
prefix=f"{prefix}.shared_experts",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.use_latent_moe:
|
||||||
|
self.fc1_latent_proj = ReplicatedLinear(
|
||||||
|
input_size=config.hidden_size,
|
||||||
|
output_size=self.moe_hidden_size,
|
||||||
|
bias=config.mlp_bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
disable_tp=self.is_sequence_parallel,
|
||||||
|
prefix=f"{prefix}.fc1_latent_proj",
|
||||||
|
)
|
||||||
|
self.fc2_latent_proj = ReplicatedLinear(
|
||||||
|
input_size=self.moe_hidden_size,
|
||||||
|
output_size=config.hidden_size,
|
||||||
|
bias=config.mlp_bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
disable_tp=self.is_sequence_parallel,
|
||||||
|
prefix=f"{prefix}.fc2_latent_proj",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.fc1_latent_proj = None
|
||||||
|
self.fc2_latent_proj = None
|
||||||
|
|
||||||
self.experts = SharedFusedMoE(
|
self.experts = SharedFusedMoE(
|
||||||
# TODO: make it possible for shared experts to have
|
shared_experts=self.shared_experts,
|
||||||
# different input in SharedFusedMoE
|
|
||||||
shared_experts=self.shared_experts if not self.use_latent_moe else None,
|
|
||||||
num_experts=config.n_routed_experts,
|
num_experts=config.n_routed_experts,
|
||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
hidden_size=self.moe_hidden_size,
|
hidden_size=self.moe_hidden_size,
|
||||||
@@ -211,30 +230,9 @@ class NemotronHMoE(nn.Module):
|
|||||||
num_redundant_experts=self.n_redundant_experts,
|
num_redundant_experts=self.n_redundant_experts,
|
||||||
is_sequence_parallel=self.is_sequence_parallel,
|
is_sequence_parallel=self.is_sequence_parallel,
|
||||||
router_logits_dtype=router_logits_dtype,
|
router_logits_dtype=router_logits_dtype,
|
||||||
|
routed_input_transform=self.fc1_latent_proj,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.use_latent_moe:
|
|
||||||
self.fc1_latent_proj = ReplicatedLinear(
|
|
||||||
input_size=config.hidden_size,
|
|
||||||
output_size=self.moe_hidden_size,
|
|
||||||
bias=config.mlp_bias,
|
|
||||||
quant_config=quant_config,
|
|
||||||
disable_tp=self.is_sequence_parallel,
|
|
||||||
prefix=f"{prefix}.fc1_latent_proj",
|
|
||||||
)
|
|
||||||
self.fc2_latent_proj = ReplicatedLinear(
|
|
||||||
input_size=self.moe_hidden_size,
|
|
||||||
output_size=config.hidden_size,
|
|
||||||
bias=config.mlp_bias,
|
|
||||||
quant_config=quant_config,
|
|
||||||
disable_tp=self.is_sequence_parallel,
|
|
||||||
prefix=f"{prefix}.fc2_latent_proj",
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.fc1_latent_proj = None
|
|
||||||
self.fc2_latent_proj = None
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
@@ -244,38 +242,28 @@ class NemotronHMoE(nn.Module):
|
|||||||
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
|
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
|
||||||
shared_output = None
|
|
||||||
if self.use_latent_moe:
|
|
||||||
if self.shared_experts is not None:
|
|
||||||
shared_output = self.shared_experts(hidden_states)
|
|
||||||
hidden_states, _ = self.fc1_latent_proj(hidden_states)
|
|
||||||
|
|
||||||
fused_moe_out = self.experts(
|
# SharedFusedMoE handles:
|
||||||
|
# - shared experts (with original hidden_states)
|
||||||
|
# - routed_input_transform (fc1_latent_proj) for latent MoE
|
||||||
|
# - multistream parallelism between shared and routed experts
|
||||||
|
shared_output, final_hidden_states = self.experts(
|
||||||
hidden_states=hidden_states, router_logits=router_logits
|
hidden_states=hidden_states, router_logits=router_logits
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.use_latent_moe:
|
|
||||||
_, final_hidden_states = fused_moe_out
|
|
||||||
else:
|
|
||||||
shared_output, final_hidden_states = fused_moe_out
|
|
||||||
|
|
||||||
# Fix FP16 overflow
|
# Fix FP16 overflow
|
||||||
# See DeepseekV2DecoderLayer for more details.
|
# See DeepseekV2DecoderLayer for more details.
|
||||||
if hidden_states.dtype != torch.float16:
|
if hidden_states.dtype != torch.float16:
|
||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
elif self.shared_experts is not None:
|
elif self.shared_experts is not None:
|
||||||
assert shared_output is not None
|
|
||||||
shared_output *= 1.0 / self.routed_scaling_factor
|
shared_output *= 1.0 / self.routed_scaling_factor
|
||||||
|
|
||||||
# TODO: currently latent up_proj is done before all-reduce for simplicity.
|
# TODO: See SharedFusedMoE.apply_routed_input_transform
|
||||||
# if and when shared experts will be part of SharedFusedMoE,
|
# for bandwidth optimization
|
||||||
# we should do the up_proj after all-reduce,
|
|
||||||
# to have the all-reduce in the smaller latent dimension.
|
|
||||||
if self.use_latent_moe:
|
if self.use_latent_moe:
|
||||||
final_hidden_states, _ = self.fc2_latent_proj(final_hidden_states)
|
final_hidden_states, _ = self.fc2_latent_proj(final_hidden_states)
|
||||||
|
|
||||||
if self.shared_experts is not None:
|
if self.shared_experts is not None:
|
||||||
assert shared_output is not None
|
|
||||||
final_hidden_states += shared_output
|
final_hidden_states += shared_output
|
||||||
|
|
||||||
if self.is_sequence_parallel:
|
if self.is_sequence_parallel:
|
||||||
|
|||||||
Reference in New Issue
Block a user