[MoE] Enable Shared/Routed Overlap For Latent MoE (Nemotron-H) (#32790)
Signed-off-by: dafrimi <dafrimi@nvidia.com>
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from collections.abc import Callable, Iterable
-from contextlib import nullcontext
+from collections.abc import Callable, Generator, Iterable
+from contextlib import contextmanager, nullcontext
 from enum import Enum
 from typing import Literal, cast, get_args, overload
 
@@ -351,6 +351,10 @@ class FusedMoE(CustomOp):
                 "Enabled separate cuda stream for MoE shared_experts", scope="local"
             )
 
+        # For latent MoE: stores original hidden_states before routed_input_transform
+        # so shared_experts can use it for cloning (they need original dimension)
+        self._shared_experts_input: torch.Tensor | None = None
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
@@ -664,6 +668,39 @@ class FusedMoE(CustomOp):
     def gate(self) -> torch.nn.Module | None:
         return None
 
+    def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Hook to transform hidden_states before passing to routed experts.
+        For latent MoE: transforms [S, hidden_size] → [S, moe_latent_size].
+        The original hidden_states is saved in _shared_experts_input so
+        shared_experts still receive the original [S, hidden_size].
+
+        Override in subclasses (e.g., SharedFusedMoE) for latent MoE.
+        """
+        return hidden_states
+
+    @contextmanager
+    def _set_shared_experts_input(
+        self, value: torch.Tensor | None
+    ) -> Generator[None, None, None]:
+        """Context manager to safely set/clear _shared_experts_input."""
+        self._shared_experts_input = value
+        try:
+            yield
+        finally:
+            self._shared_experts_input = None
+
+    def _get_shared_experts_input(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Get input for shared experts.
+
+        For latent MoE: shared_experts need original [S, hidden_size],
+        not the transformed [S, latent_size] used by routed experts.
+        """
+        return (
+            self._shared_experts_input
+            if self._shared_experts_input is not None
+            else hidden_states
+        )
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
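A minimal standalone sketch of the stash-and-fallback pattern these two helpers implement, for reference; the class name and tensor sizes below are illustrative only, not vLLM code:

from collections.abc import Generator
from contextlib import contextmanager

import torch


class StashDemo:
    """Illustrative stand-in for the FusedMoE stash; not vLLM code."""

    def __init__(self) -> None:
        self._shared_experts_input: torch.Tensor | None = None

    @contextmanager
    def _set_shared_experts_input(self, value: torch.Tensor | None) -> Generator[None, None, None]:
        self._shared_experts_input = value
        try:
            yield
        finally:
            # Cleared even if the wrapped forward raises.
            self._shared_experts_input = None

    def _get_shared_experts_input(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Fall back to whatever tensor the caller passed when nothing was stashed.
        return self._shared_experts_input if self._shared_experts_input is not None else hidden_states


layer = StashDemo()
original = torch.randn(4, 8)   # e.g. [S, hidden_size]
latent = torch.randn(4, 2)     # e.g. [S, moe_latent_size]
with layer._set_shared_experts_input(original):
    assert layer._get_shared_experts_input(latent) is original   # shared experts see the original
assert layer._get_shared_experts_input(latent) is latent          # cleared after the block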
@@ -855,9 +892,11 @@ class FusedMoE(CustomOp):
         if use_shared_experts_stream:
             assert self.shared_experts_stream is not None
 
+            shared_experts_input = self._get_shared_experts_input(hidden_states)
+
             # Clone BEFORE switching streams to avoid race condition
             # where routed_expert kernel may mutate hidden_states.
-            hidden_states_clone = hidden_states.clone()
+            hidden_states_clone = shared_experts_input.clone()
 
             # Record that the clone will be used by shared_experts_stream
             # to avoid gc issue from deallocation of hidden_states_clone
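For context, a minimal sketch (assuming CUDA is available; not the vLLM implementation) of why the clone is taken on the current stream and then registered with the side stream: record_stream tells the caching allocator the clone is still in use on shared_experts_stream, so its memory is not reused if the Python reference is dropped early.

import torch

if torch.cuda.is_available():
    shared_experts_stream = torch.cuda.Stream()
    hidden_states = torch.randn(4, 8, device="cuda")
    shared_experts_input = hidden_states              # stand-in for _get_shared_experts_input(...)

    # Clone on the current stream, before the routed-expert kernels can mutate hidden_states.
    hidden_states_clone = shared_experts_input.clone()
    # Tell the caching allocator the clone will also be consumed on the side stream.
    hidden_states_clone.record_stream(shared_experts_stream)

    shared_experts_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(shared_experts_stream):
        shared_output = hidden_states_clone * 2.0     # stand-in for the shared-experts computation
    torch.cuda.current_stream().wait_stream(shared_experts_stream)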
@@ -1537,11 +1576,20 @@ class FusedMoE(CustomOp):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        og_hidden_states = hidden_states.shape[-1]
-        if self.hidden_size != og_hidden_states:
+        # For latent MoE: save ORIGINAL hidden_states before transform
+        # (shared_experts need original dimension, routed experts use transformed)
+        original_hidden_states = hidden_states
+        original_hidden_dim = hidden_states.shape[-1]
+
+        # Apply transform for routed experts (e.g., latent projection for latent MoE)
+        hidden_states = self.apply_routed_input_transform(hidden_states)
+
+        # This is the dimension after transform (for routed expert output slicing)
+        transformed_hidden_dim = hidden_states.shape[-1]
+        if self.hidden_size != transformed_hidden_dim:
             hidden_states = F.pad(
                 hidden_states,
-                (0, self.hidden_size - og_hidden_states),
+                (0, self.hidden_size - transformed_hidden_dim),
                 mode="constant",
                 value=0.0,
             )
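A worked shape example for the transform-and-pad flow above; the sizes are made up, not taken from any Nemotron-H config, and F.pad only right-pads when the transformed width is smaller than the layer's hidden_size:

import torch
import torch.nn.functional as F

hidden_size = 1024                       # stand-in for self.hidden_size (made-up value)
moe_latent_size = 512                    # made-up latent width produced by the transform

original_hidden_states = torch.randn(6, hidden_size)      # [S, hidden_size]
hidden_states = torch.randn(6, moe_latent_size)           # after apply_routed_input_transform

original_hidden_dim = original_hidden_states.shape[-1]
transformed_hidden_dim = hidden_states.shape[-1]
if hidden_size != transformed_hidden_dim:
    # Right-pad the routed input with zeros up to the kernel width.
    hidden_states = F.pad(
        hidden_states, (0, hidden_size - transformed_hidden_dim), mode="constant", value=0.0
    )

assert hidden_states.shape == (6, hidden_size)
# Routed-expert output is later sliced back to transformed_hidden_dim, while the
# shared-expert output keeps original_hidden_dim, matching the two slice widths in forward().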
@@ -1576,22 +1624,31 @@ class FusedMoE(CustomOp):
                 fused_output = torch.ops.vllm.moe_forward(
                     hidden_states, router_logits, encode_layer_name()
                 )
-            return reduce_output(fused_output)[..., :og_hidden_states]
+            return reduce_output(fused_output)[..., :transformed_hidden_dim]
         else:
             if current_platform.is_tpu() or current_platform.is_cpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
                 # will switch to using the moe_forward custom op.
-                shared_output, fused_output = self.forward_impl(
-                    hidden_states, router_logits
-                )
+                # Note: CPU doesn't require wrapped forward_impl.
+                with self._set_shared_experts_input(original_hidden_states):
+                    shared_output, fused_output = self.forward_impl(
+                        hidden_states, router_logits
+                    )
             else:
+                # Custom op handles setting/clearing _shared_experts_input internally
+                # We pass original tensor for shared experts (not transformed)
                 shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
-                    hidden_states, router_logits, encode_layer_name()
+                    hidden_states,
+                    router_logits,
+                    encode_layer_name(),
+                    original_hidden_states,
                 )
 
+            # shared_output uses original dimension (before transform)
+            # fused_output uses transformed dimension (after transform)
             return (
-                reduce_output(shared_output)[..., :og_hidden_states],
-                reduce_output(fused_output)[..., :og_hidden_states],
+                reduce_output(shared_output)[..., :original_hidden_dim],
+                reduce_output(fused_output)[..., :transformed_hidden_dim],
             )
 
     @property
@@ -1831,7 +1888,8 @@ class FusedMoE(CustomOp):
             # because matrix multiply maybe modify the hidden_states.
             if has_separate_shared_experts and not use_shared_experts_stream:
                 assert self.shared_experts is not None
-                shared_output = self.shared_experts(hidden_states)
+                shared_input = self._get_shared_experts_input(hidden_states)
+                shared_output = self.shared_experts(shared_input)
 
             # NOTE: Similar with DP, PCP also needs dispatch and combine. For
             # simplicity, AgRsAll2All was added separately for PCP here. Maybe
@@ -2021,19 +2079,34 @@ def moe_forward_shared(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     layer_name: str,
+    shared_experts_input: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     self = get_layer_from_name(layer_name)
     assert self.shared_experts is not None
-    return self.forward_impl(hidden_states, router_logits)
+
+    # Set here because torch.compile skips forward_native() setup code
+    # and calls this op directly. forward_impl() reads from this var.
+    with self._set_shared_experts_input(shared_experts_input):
+        return self.forward_impl(hidden_states, router_logits)
 
 
 def moe_forward_shared_fake(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     layer_name: str,
+    shared_experts_input: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    shared_out = torch.empty_like(hidden_states)
+    # Output shapes:
+    # - fused_out: same as hidden_states (routed experts use transformed size)
+    # - shared_out: same as shared_experts_input if provided, else same as hidden_states
+    # (For latent MoE: shared experts use original hidden_size, not latent size)
     fused_out = torch.empty_like(hidden_states)
+
+    if shared_experts_input is not None:
+        shared_out = torch.empty_like(shared_experts_input)
+    else:
+        shared_out = torch.empty_like(hidden_states)
+
     return shared_out, fused_out
 
 
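A generic, self-contained illustration of the shape contract the fake implementation encodes, assuming PyTorch 2.4+'s torch.library.custom_op / register_fake (vLLM registers its ops through its own helper, so this is only an analogy): under torch.compile the fake is what gets traced, so shared_out must follow the width of the shared-experts input rather than the (possibly latent-sized) routed input.

import torch


@torch.library.custom_op("demo::moe_forward_shared", mutates_args=())
def demo_moe_forward_shared(
    hidden_states: torch.Tensor, shared_experts_input: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # Stand-ins for the shared-expert and routed-expert computations.
    return shared_experts_input * 2.0, hidden_states * 3.0


@demo_moe_forward_shared.register_fake
def _(hidden_states, shared_experts_input):
    # Shapes only: shared output follows the original width, fused output the routed width.
    return torch.empty_like(shared_experts_input), torch.empty_like(hidden_states)


latent = torch.randn(4, 2)     # routed (transformed) input, e.g. [S, moe_latent_size]
original = torch.randn(4, 8)   # original-width input for the shared experts
shared, fused = demo_moe_forward_shared(latent, original)
assert shared.shape == (4, 8) and fused.shape == (4, 2)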
@@ -23,10 +23,12 @@ class SharedFusedMoE(FusedMoE):
         shared_experts: torch.nn.Module | None,
         gate: torch.nn.Module | None = None,
         use_overlapped: bool = True,
+        routed_input_transform: torch.nn.Module | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
+        self._routed_input_transform = routed_input_transform
 
         # Disable shared expert overlap if:
         # - we are using eplb with non-default backend, because of correctness issues
@@ -56,6 +58,26 @@ class SharedFusedMoE(FusedMoE):
     def is_internal_router(self) -> bool:
         return self.gate is not None
 
+    def apply_routed_input_transform(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Apply transform for routed experts (e.g., latent projection).
+
+        This is called by FusedMoE.forward_native. The original hidden_states
+        is saved separately so shared experts get [S, hidden_size] while
+        routed experts get the transformed [S, moe_latent_size].
+
+        TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
+        moved inside SharedFusedMoE to all-reduce on the smaller latent
+        dimension.
+        """
+        if self._routed_input_transform is not None:
+            result = self._routed_input_transform(hidden_states)
+            # ReplicatedLinear returns (output, extra_bias) tuple.
+            # We only need the output tensor; extra_bias is not used here.
+            if isinstance(result, tuple):
+                return result[0]
+            return result
+        return hidden_states
+
     def forward(
         self,
         hidden_states: torch.Tensor,
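A small sketch of the unwrapping rule in apply_routed_input_transform, using a stand-in module that mimics ReplicatedLinear's (output, extra_bias) return convention; the names and sizes here are illustrative only:

import torch
from torch import nn


class FakeReplicatedLinear(nn.Module):
    """Stand-in that mimics ReplicatedLinear's (output, extra_bias) return; not vLLM code."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
        return self.linear(x), None


def apply_routed_input_transform(transform: nn.Module | None, hidden_states: torch.Tensor) -> torch.Tensor:
    # Same unwrapping rule as the method above: keep only the output tensor.
    if transform is None:
        return hidden_states
    result = transform(hidden_states)
    return result[0] if isinstance(result, tuple) else result


proj = FakeReplicatedLinear(8, 2)   # hidden_size=8 -> moe_latent_size=2 (made-up sizes)
x = torch.randn(4, 8)
assert apply_routed_input_transform(proj, x).shape == (4, 2)
assert apply_routed_input_transform(None, x) is x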