[MoE Refactor] DefaultMoERunner simplifcation (#33049)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -504,6 +504,8 @@ class FusedMoE(CustomOp):
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
self.activation = MoEActivation.from_str(activation)
|
||||
|
||||
# TODO(bnell): we should not have to create a router if the kernel is
|
||||
# monolithic.
|
||||
self.router = create_fused_moe_router(
|
||||
top_k=top_k,
|
||||
global_num_experts=self.global_num_experts,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -82,8 +83,21 @@ def _moe_forward(
|
||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||
# TODO(bnell): this can be removed after MK migration is complete.
|
||||
layer.ensure_moe_quant_config_init()
|
||||
return layer.runner.forward_impl(
|
||||
layer, hidden_states, router_logits, shared_experts_input
|
||||
runner = layer.runner
|
||||
with runner._sequence_parallel_context():
|
||||
if runner.use_dp_chunking:
|
||||
return runner.forward_impl_chunked(
|
||||
layer,
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
)
|
||||
else:
|
||||
return runner.forward_impl(
|
||||
layer,
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
@@ -105,8 +119,21 @@ def _moe_forward_shared(
|
||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||
# TODO(bnell): this can be removed after MK migration is complete.
|
||||
layer.ensure_moe_quant_config_init()
|
||||
return layer.runner.forward_impl(
|
||||
layer, hidden_states, router_logits, shared_experts_input
|
||||
runner = layer.runner
|
||||
with runner._sequence_parallel_context():
|
||||
if runner.use_dp_chunking:
|
||||
return runner.forward_impl_chunked(
|
||||
layer,
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
)
|
||||
else:
|
||||
return runner.forward_impl(
|
||||
layer,
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_experts_input,
|
||||
)
|
||||
|
||||
|
||||
@@ -191,10 +218,17 @@ class DefaultMoERunner(MoERunner):
|
||||
self.reduce_results = reduce_results
|
||||
self.enable_dbo = enable_dbo
|
||||
|
||||
# Chunked all2all staging tensor
|
||||
# TODO(bnell) rename these?
|
||||
self.batched_hidden_states: torch.Tensor | None = None
|
||||
self.batched_router_logits: torch.Tensor | None = None
|
||||
self._maybe_init_dp_chunking()
|
||||
|
||||
# Allow disabling of the separate shared experts stream for
|
||||
# debug purposes.
|
||||
# TODO: Remove this after more extensive testings with TP/DP
|
||||
# and other execution modes
|
||||
self.use_shared_experts_stream = False
|
||||
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
|
||||
logger.debug_once("Disabling MoE shared_experts cuda stream", scope="local")
|
||||
self.shared_experts_stream = None
|
||||
@@ -210,23 +244,20 @@ class DefaultMoERunner(MoERunner):
|
||||
# Needed for string -> FusedMoE layer lookup in custom ops.
|
||||
self.layer_name = layer.layer_name
|
||||
|
||||
self.moe_forward = self._select_forward(layer)
|
||||
|
||||
def _select_forward(self, layer: torch.nn.Module) -> Callable:
|
||||
if current_platform.is_tpu() or current_platform.is_cpu():
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we
|
||||
# will switch to using the moe_forward custom op.
|
||||
# Note: CPU doesn't require wrapped forward_impl.
|
||||
if self.shared_experts is None:
|
||||
self.moe_forward = _moe_forward
|
||||
else:
|
||||
self.moe_forward = _moe_forward_shared
|
||||
else:
|
||||
if self.shared_experts is None:
|
||||
self.moe_forward = torch.ops.vllm.moe_forward
|
||||
else:
|
||||
self.moe_forward = torch.ops.vllm.moe_forward_shared
|
||||
return _moe_forward if self.shared_experts is None else _moe_forward_shared
|
||||
|
||||
# Chunked all2all staging tensor
|
||||
self.batched_hidden_states: torch.Tensor | None = None
|
||||
self.batched_router_logits: torch.Tensor | None = None
|
||||
return (
|
||||
torch.ops.vllm.moe_forward
|
||||
if self.shared_experts is None
|
||||
else torch.ops.vllm.moe_forward_shared
|
||||
)
|
||||
|
||||
@property
|
||||
def use_dp_chunking(self) -> bool:
|
||||
@@ -241,22 +272,8 @@ class DefaultMoERunner(MoERunner):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
shared_input: torch.Tensor | None,
|
||||
has_separate_shared_experts: bool,
|
||||
use_chunked_impl: bool,
|
||||
) -> tuple[bool, torch.Tensor | None]:
|
||||
use_shared_experts_stream = (
|
||||
current_platform.is_cuda()
|
||||
and has_separate_shared_experts
|
||||
and not use_chunked_impl
|
||||
and self.shared_experts_stream is not None
|
||||
and (
|
||||
hidden_states.shape[0]
|
||||
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
|
||||
)
|
||||
)
|
||||
|
||||
shared_experts_input: torch.Tensor | None = None
|
||||
if use_shared_experts_stream:
|
||||
):
|
||||
if self.use_shared_experts_stream:
|
||||
assert self.shared_experts_stream is not None
|
||||
assert self.moe_config.disable_inplace
|
||||
|
||||
@@ -278,12 +295,11 @@ class DefaultMoERunner(MoERunner):
|
||||
assert self.shared_experts_stream is not None
|
||||
self.shared_experts_stream.wait_stream(current_stream())
|
||||
|
||||
return use_shared_experts_stream, shared_experts_input
|
||||
|
||||
def ensure_dp_chunking_init(self):
|
||||
if not self.use_dp_chunking or self.batched_hidden_states is not None:
|
||||
def _maybe_init_dp_chunking(self):
|
||||
if not self.use_dp_chunking:
|
||||
return
|
||||
|
||||
assert self.batched_hidden_states is None
|
||||
states_shape: tuple[int, ...]
|
||||
logits_shape: tuple[int, ...]
|
||||
|
||||
@@ -309,6 +325,38 @@ class DefaultMoERunner(MoERunner):
|
||||
device=device,
|
||||
)
|
||||
|
||||
@property
|
||||
def has_separate_shared_experts(self) -> bool:
|
||||
return (
|
||||
not self.quant_method.mk_owns_shared_expert
|
||||
and self.shared_experts is not None
|
||||
)
|
||||
|
||||
def _apply_shared_experts(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
allow_streaming: bool = False,
|
||||
) -> torch.Tensor | None:
|
||||
shared_output: torch.Tensor | None = None
|
||||
if self.has_separate_shared_experts:
|
||||
assert self.shared_experts is not None
|
||||
|
||||
if self.use_shared_experts_stream and allow_streaming:
|
||||
# Run shared experts in parallel on a separate stream
|
||||
# NOTE: We start the separate stream here and mark the
|
||||
# sync end point immediately after it is done. This is
|
||||
# important to avoid excessive stream allocations by the cuda
|
||||
# graph replay later.
|
||||
with torch.cuda.stream(self.shared_experts_stream):
|
||||
# Note that hidden_states clone() is necessary here to avoid
|
||||
# conflict with the main stream
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
current_stream().wait_stream(self.shared_experts_stream)
|
||||
else:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
return shared_output
|
||||
|
||||
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||
"""
|
||||
The shared_experts are typically computed using the RowParallelLinear
|
||||
@@ -322,7 +370,6 @@ class DefaultMoERunner(MoERunner):
|
||||
Therefore it is required that we reduce the shared_experts output
|
||||
early.
|
||||
"""
|
||||
assert self.quant_method is not None
|
||||
return (
|
||||
self.quant_method.moe_kernel is not None
|
||||
and self.quant_method.moe_kernel.output_is_reduced()
|
||||
@@ -357,7 +404,7 @@ class DefaultMoERunner(MoERunner):
|
||||
return result
|
||||
return hidden_states
|
||||
|
||||
def _reduce_output(
|
||||
def _maybe_reduce_output(
|
||||
self,
|
||||
states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
trunc_sizes: list[int],
|
||||
@@ -397,23 +444,16 @@ class DefaultMoERunner(MoERunner):
|
||||
return "from_forward_context"
|
||||
return self.layer_name
|
||||
|
||||
def forward(
|
||||
def _maybe_pad_hidden_states(
|
||||
self,
|
||||
original_hidden_states: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
# For latent MoE: save ORIGINAL hidden_states before transform
|
||||
# (shared_experts need original dimension, routed experts use transformed)
|
||||
if self.shared_experts is not None:
|
||||
original_hidden_states = hidden_states
|
||||
original_hidden_dim = hidden_states.shape[-1]
|
||||
else:
|
||||
original_hidden_states = None
|
||||
|
||||
# Apply transform for routed experts (e.g., latent projection for latent MoE)
|
||||
hidden_states = self.apply_routed_input_transform(hidden_states)
|
||||
|
||||
# This is the dimension after transform (for routed expert output slicing)
|
||||
) -> tuple[torch.Tensor, list[int]]:
|
||||
original_hidden_dim = (
|
||||
original_hidden_states.shape[-1]
|
||||
if original_hidden_states is not None
|
||||
else 0
|
||||
)
|
||||
transformed_hidden_dim = hidden_states.shape[-1]
|
||||
if (
|
||||
not self.quant_method.skip_forward_padding
|
||||
@@ -426,6 +466,192 @@ class DefaultMoERunner(MoERunner):
|
||||
value=0.0,
|
||||
)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
orig_hidden_dims = [original_hidden_dim, transformed_hidden_dim]
|
||||
else:
|
||||
orig_hidden_dims = [transformed_hidden_dim]
|
||||
|
||||
return hidden_states, orig_hidden_dims
|
||||
|
||||
def _apply_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_input: torch.Tensor | None,
|
||||
run_shared_experts_before: bool = True,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor]:
|
||||
shared_input = shared_input if shared_input is not None else hidden_states
|
||||
shared_output: torch.Tensor | None = None
|
||||
|
||||
# Run this before quant_method to avoid inplace issues.
|
||||
if run_shared_experts_before:
|
||||
shared_output = self._apply_shared_experts(shared_input, False)
|
||||
|
||||
if self.quant_method.is_monolithic:
|
||||
result = self.quant_method.apply_monolithic(
|
||||
layer=layer,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = self.router.select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
result = self.quant_method.apply(
|
||||
layer=layer,
|
||||
x=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
shared_experts_input=shared_input,
|
||||
)
|
||||
|
||||
if isinstance(result, tuple):
|
||||
assert shared_output is None
|
||||
shared_output, hidden_states = result
|
||||
else:
|
||||
hidden_states = result
|
||||
|
||||
if not run_shared_experts_before and self.has_separate_shared_experts:
|
||||
assert shared_output is None
|
||||
shared_output = self._apply_shared_experts(shared_input, True)
|
||||
|
||||
return shared_output, hidden_states
|
||||
|
||||
def _sequence_parallel_context(self):
|
||||
ctx = get_forward_context()
|
||||
return (
|
||||
ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
|
||||
if ctx.dp_metadata
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
def _allocate_dp_chunking_outputs(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor]:
|
||||
assert self.use_dp_chunking
|
||||
|
||||
# Assert the inputs are of the proper type and shape.
|
||||
assert self.batched_hidden_states is not None
|
||||
assert self.batched_router_logits is not None
|
||||
|
||||
assert self.batched_hidden_states.dtype == hidden_states.dtype, (
|
||||
f"{self.batched_hidden_states.dtype} == {hidden_states.dtype}"
|
||||
)
|
||||
assert self.batched_router_logits.dtype == router_logits.dtype, (
|
||||
f"{self.batched_router_logits.dtype} == {router_logits.dtype}"
|
||||
)
|
||||
|
||||
# Check size compatibility.
|
||||
assert self.batched_hidden_states.size(-1) == hidden_states.size(-1)
|
||||
assert self.batched_router_logits.size(-1) == router_logits.size(-1)
|
||||
|
||||
final_fused_hidden_states = torch.empty_like(hidden_states)
|
||||
if self.shared_experts is not None:
|
||||
final_shared_hidden_states = torch.empty_like(hidden_states)
|
||||
else:
|
||||
final_shared_hidden_states = None
|
||||
|
||||
return final_shared_hidden_states, final_fused_hidden_states
|
||||
|
||||
def _maybe_gate(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# If router/gate provided, then apply it here.
|
||||
# (Note: This code runs only when "overlapped mode" is on to allow
|
||||
# parallel execution of shared experts with the FusedMoE via
|
||||
# separate cuda stream)
|
||||
if self.gate is not None:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
return router_logits
|
||||
|
||||
@property
|
||||
def do_naive_dispatch_combine(self) -> bool:
|
||||
return (
|
||||
self.moe_config.dp_size > 1 and not self.quant_method.supports_internal_mk
|
||||
)
|
||||
|
||||
def _maybe_dispatch(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
|
||||
# router logits to all experts.
|
||||
# NOTE: this will be removed once all kernels are migrated into the
|
||||
# MoEKernel framework.
|
||||
if self.do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
self.moe_config.is_sequence_parallel,
|
||||
)
|
||||
|
||||
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
|
||||
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
|
||||
# we should modify All2AllManager abstraction to better support PCP.
|
||||
if self.moe_config.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().all_gather(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
)
|
||||
router_logits = get_pcp_group().all_gather(
|
||||
router_logits,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
return hidden_states, router_logits
|
||||
|
||||
def _maybe_combine(
|
||||
self,
|
||||
shared_output: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if self.do_naive_dispatch_combine:
|
||||
hidden_states = get_ep_group().combine(
|
||||
hidden_states, self.moe_config.is_sequence_parallel
|
||||
)
|
||||
|
||||
if self.moe_config.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().reduce_scatter(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
)
|
||||
# need RS for shared_output?
|
||||
|
||||
if self.shared_experts is not None:
|
||||
assert shared_output is not None
|
||||
return shared_output, hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
# For latent MoE: save ORIGINAL hidden_states before transform
|
||||
# (shared_experts need original dimension, routed experts use transformed)
|
||||
if self.shared_experts is not None:
|
||||
original_hidden_states = hidden_states
|
||||
else:
|
||||
original_hidden_states = None
|
||||
|
||||
# Apply transform for routed experts (e.g., latent projection for latent MoE)
|
||||
hidden_states = self.apply_routed_input_transform(hidden_states)
|
||||
|
||||
hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(
|
||||
original_hidden_states,
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
fused_output = self.moe_forward(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
@@ -433,126 +659,41 @@ class DefaultMoERunner(MoERunner):
|
||||
self._encode_layer_name(),
|
||||
)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
orig_hidden_dims = [original_hidden_dim, transformed_hidden_dim]
|
||||
else:
|
||||
orig_hidden_dims = [transformed_hidden_dim]
|
||||
return self._maybe_reduce_output(fused_output, og_hidden_dims)
|
||||
|
||||
return self._reduce_output(fused_output, orig_hidden_dims)
|
||||
def _slice_and_copy_input(
|
||||
self,
|
||||
out_slice: torch.Tensor,
|
||||
orig: torch.Tensor | None,
|
||||
start: int,
|
||||
end: int,
|
||||
) -> torch.Tensor:
|
||||
assert orig is not None
|
||||
slice_size = end - start
|
||||
orig_slice = orig[start:end, :]
|
||||
if self.enable_dbo:
|
||||
assert out_slice.dim() == 3
|
||||
batch_buffer_idx = dbo_current_ubatch_id()
|
||||
out_slice = out_slice[batch_buffer_idx, :]
|
||||
|
||||
assert out_slice.size(0) >= slice_size
|
||||
out_slice = out_slice[:slice_size, :]
|
||||
out_slice.copy_(orig_slice, non_blocking=True)
|
||||
return out_slice
|
||||
|
||||
def forward_impl_chunked(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
full_hidden_states: torch.Tensor,
|
||||
full_router_logits: torch.Tensor,
|
||||
full_shared_input: torch.Tensor | None,
|
||||
has_separate_shared_experts: bool,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.batched_hidden_states is not None
|
||||
assert self.batched_router_logits is not None
|
||||
assert self.batched_hidden_states.dtype == full_hidden_states.dtype, (
|
||||
f"{self.batched_hidden_states.dtype} == {full_hidden_states.dtype}"
|
||||
)
|
||||
assert self.batched_router_logits.dtype == full_router_logits.dtype, (
|
||||
f"{self.batched_router_logits.dtype} == {full_router_logits.dtype}"
|
||||
)
|
||||
# Check size compatibility.
|
||||
assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
|
||||
assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
|
||||
# Gate overlap not supported when chunking is enabled. Run the
|
||||
# gate first.
|
||||
router_logits = self._maybe_gate(hidden_states, router_logits)
|
||||
|
||||
# TODO(bnell): Fix shared_expert_inputs w/chunking.
|
||||
# assert shared_input is None, (
|
||||
# "Routed input transform is not currently supported with DP chunking."
|
||||
# )
|
||||
|
||||
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
if self.shared_experts is not None:
|
||||
full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
|
||||
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
|
||||
chunk_size = chunk_end - chunk_start
|
||||
hidden_states = full_hidden_states[chunk_start:chunk_end, :]
|
||||
router_logits = full_router_logits[chunk_start:chunk_end, :]
|
||||
shared_input = (
|
||||
full_shared_input[chunk_start:chunk_end, :]
|
||||
if full_shared_input is not None
|
||||
else None
|
||||
)
|
||||
|
||||
assert self.batched_hidden_states is not None
|
||||
assert self.batched_router_logits is not None
|
||||
# This is only true when DBO has been enabled in the config.
|
||||
# Both tensors will have an outer dimension for the ubatch id
|
||||
if self.batched_hidden_states.dim() == 3:
|
||||
assert self.batched_router_logits.dim() == 3
|
||||
batch_buffer_idx = dbo_current_ubatch_id()
|
||||
batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
|
||||
batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
|
||||
else:
|
||||
batched_hidden_states = self.batched_hidden_states
|
||||
batched_router_logits = self.batched_router_logits
|
||||
|
||||
assert (
|
||||
batched_hidden_states.size(0) # type: ignore
|
||||
>= chunk_size
|
||||
)
|
||||
assert (
|
||||
batched_router_logits.size(0) # type: ignore
|
||||
>= chunk_size
|
||||
)
|
||||
staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
|
||||
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
|
||||
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
||||
staged_router_logits.copy_(router_logits, non_blocking=True)
|
||||
|
||||
shared_input = (
|
||||
shared_input if shared_input is not None else staged_hidden_states
|
||||
)
|
||||
|
||||
# Matrix multiply.
|
||||
if self.quant_method.is_monolithic:
|
||||
assert has_separate_shared_experts or self.shared_experts is None
|
||||
final_hidden_states = self.quant_method.apply_monolithic(
|
||||
layer=layer,
|
||||
x=staged_hidden_states,
|
||||
router_logits=staged_router_logits,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = self.router.select_experts(
|
||||
hidden_states=staged_hidden_states,
|
||||
router_logits=staged_router_logits,
|
||||
)
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=layer,
|
||||
x=staged_hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
shared_experts_input=shared_input,
|
||||
)
|
||||
|
||||
if has_separate_shared_experts:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
|
||||
shared_output = self.shared_experts(shared_input)
|
||||
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
|
||||
if not skip_result_store:
|
||||
if self.shared_experts is None:
|
||||
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states, non_blocking=True
|
||||
)
|
||||
else:
|
||||
full_shared_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states[0], non_blocking=True
|
||||
)
|
||||
full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
final_hidden_states[1], non_blocking=True
|
||||
final_shared_hidden_states, final_fused_hidden_states = (
|
||||
self._allocate_dp_chunking_outputs(hidden_states, router_logits)
|
||||
)
|
||||
|
||||
ctx = get_forward_context()
|
||||
@@ -567,7 +708,7 @@ class DefaultMoERunner(MoERunner):
|
||||
max_tokens_across_dispatchers, self.moe_config.sp_size
|
||||
)
|
||||
|
||||
num_tokens = full_hidden_states.size(0)
|
||||
num_tokens = hidden_states.size(0)
|
||||
for chunk_idx, chunk_start_ in enumerate(
|
||||
range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
|
||||
):
|
||||
@@ -578,17 +719,55 @@ class DefaultMoERunner(MoERunner):
|
||||
# clamp start and end
|
||||
chunk_start = min(chunk_start, num_tokens - 1)
|
||||
chunk_end = min(chunk_end, num_tokens)
|
||||
with ctx.dp_metadata.chunked_sizes(
|
||||
chunk_sizes = ctx.dp_metadata.chunked_sizes(
|
||||
self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
|
||||
):
|
||||
process_chunk(
|
||||
chunk_start, chunk_end, skip_result_store=chunk_start_ >= num_tokens
|
||||
)
|
||||
with chunk_sizes:
|
||||
hidden_states_chunk = self._slice_and_copy_input(
|
||||
self.batched_hidden_states,
|
||||
hidden_states,
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
)
|
||||
|
||||
router_logits_chunk = self._slice_and_copy_input(
|
||||
self.batched_router_logits,
|
||||
router_logits,
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
)
|
||||
|
||||
shared_input_chunk = (
|
||||
shared_input[chunk_start:chunk_end, :]
|
||||
if shared_input is not None
|
||||
else None
|
||||
)
|
||||
|
||||
shared_output_chunk, hidden_states_chunk = self._apply_quant_method(
|
||||
layer=layer,
|
||||
hidden_states=hidden_states_chunk,
|
||||
router_logits=router_logits_chunk,
|
||||
shared_input=shared_input_chunk,
|
||||
)
|
||||
|
||||
# Store outputs
|
||||
# TODO(bnell): document when chunk_start >= num_tokens
|
||||
if chunk_start < num_tokens:
|
||||
final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
hidden_states_chunk, non_blocking=True
|
||||
)
|
||||
if self.shared_experts is not None:
|
||||
assert shared_output_chunk is not None
|
||||
assert final_shared_hidden_states is not None
|
||||
final_shared_hidden_states[chunk_start:chunk_end, :].copy_(
|
||||
shared_output_chunk, non_blocking=True
|
||||
)
|
||||
|
||||
if self.shared_experts is None:
|
||||
return full_fused_final_hidden_states
|
||||
return final_fused_hidden_states
|
||||
else:
|
||||
return (full_shared_final_hidden_states, full_fused_final_hidden_states)
|
||||
assert final_shared_hidden_states is not None
|
||||
return (final_shared_hidden_states, final_fused_hidden_states)
|
||||
|
||||
def forward_impl(
|
||||
self,
|
||||
@@ -597,148 +776,51 @@ class DefaultMoERunner(MoERunner):
|
||||
router_logits: torch.Tensor,
|
||||
shared_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.ensure_dp_chunking_init()
|
||||
|
||||
has_separate_shared_experts = (
|
||||
not self.quant_method.mk_owns_shared_expert
|
||||
and self.shared_experts is not None
|
||||
self.use_shared_experts_stream = (
|
||||
current_platform.is_cuda()
|
||||
and self.has_separate_shared_experts
|
||||
and not self.use_dp_chunking
|
||||
and self.shared_experts_stream is not None
|
||||
and (
|
||||
hidden_states.shape[0]
|
||||
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
|
||||
)
|
||||
)
|
||||
|
||||
use_chunked_impl = self.use_dp_chunking
|
||||
# Check if we need to run shared experts before matrix multiply because
|
||||
# matrix multiply may modify the hidden_states.
|
||||
run_shared_experts_before = (
|
||||
self.has_separate_shared_experts and not self.use_shared_experts_stream
|
||||
)
|
||||
|
||||
use_shared_experts_stream, shared_experts_input = (
|
||||
# The shared experts stream must be set up before calling the gate so they
|
||||
# can be overlapped.
|
||||
if not run_shared_experts_before:
|
||||
self._maybe_setup_shared_experts_stream(
|
||||
hidden_states,
|
||||
shared_input,
|
||||
has_separate_shared_experts,
|
||||
use_chunked_impl,
|
||||
)
|
||||
)
|
||||
|
||||
# If router/gate provided, then apply it here.
|
||||
# (Note: This code runs only when "overlapped mode" is on to allow
|
||||
# parallel execution of shared experts with the FusedMoE via
|
||||
# separate cuda stream)
|
||||
if self.gate is not None:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
router_logits = self._maybe_gate(hidden_states, router_logits)
|
||||
|
||||
if use_chunked_impl:
|
||||
return self.forward_impl_chunked(
|
||||
# TODO(bnell): parts of the dispatch/combine steps will go away once
|
||||
# #32567 lands and the remaining kernels are made MKs. The PCP
|
||||
# code will probably remain
|
||||
hidden_states, router_logits = self._maybe_dispatch(
|
||||
layer,
|
||||
hidden_states,
|
||||
router_logits,
|
||||
shared_input,
|
||||
has_separate_shared_experts,
|
||||
)
|
||||
|
||||
# NOTE(rob): once we finish migrating all the quant methods to use
|
||||
# MKs, we can remove the naive dispatch/combine path from here.
|
||||
do_naive_dispatch_combine = (
|
||||
self.moe_config.dp_size > 1 and not self.quant_method.supports_internal_mk
|
||||
)
|
||||
|
||||
ctx = get_forward_context()
|
||||
sp_ctx = (
|
||||
ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
|
||||
if ctx.dp_metadata
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
with sp_ctx:
|
||||
# Run shared experts before matrix multiply.
|
||||
# because matrix multiply maybe modify the hidden_states.
|
||||
if has_separate_shared_experts and not use_shared_experts_stream:
|
||||
assert self.shared_experts is not None
|
||||
shared_input = (
|
||||
shared_input if shared_input is not None else hidden_states
|
||||
)
|
||||
shared_output = self.shared_experts(shared_input)
|
||||
|
||||
# For naive dispatch/combine Dp/Ep, dispatch the hidden states and
|
||||
# router logits to all experts.
|
||||
# NOTE: this will be removed once all kernels are migrated into the
|
||||
# MoEKernel framework.
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
self.moe_config.is_sequence_parallel,
|
||||
)
|
||||
|
||||
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
|
||||
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
|
||||
# we should modify All2AllManager abstract to better support PCP.
|
||||
if self.moe_config.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().all_gather(
|
||||
hidden_states,
|
||||
dim=0,
|
||||
)
|
||||
router_logits = get_pcp_group().all_gather(
|
||||
router_logits,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Matrix multiply.
|
||||
if self.quant_method.is_monolithic:
|
||||
final_hidden_states = self.quant_method.apply_monolithic(
|
||||
shared_output, hidden_states = self._apply_quant_method(
|
||||
layer=layer,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = self.router.select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
shared_input=shared_input,
|
||||
run_shared_experts_before=run_shared_experts_before,
|
||||
)
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=layer,
|
||||
x=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
shared_experts_input=shared_input,
|
||||
)
|
||||
|
||||
if has_separate_shared_experts:
|
||||
assert self.shared_experts is not None
|
||||
|
||||
if use_shared_experts_stream:
|
||||
# Run shared experts in parallel on a separate stream
|
||||
# NOTE: We start the separate stream here and mark the
|
||||
# sync end point immediately after it is done. This is
|
||||
# important to avoid excessive stream allocations by the cuda
|
||||
# graph replay later.
|
||||
with torch.cuda.stream(self.shared_experts_stream):
|
||||
# Note that hidden_states clone() is necessary here to avoid
|
||||
# conflict with the main stream
|
||||
shared_output = self.shared_experts(shared_experts_input)
|
||||
current_stream().wait_stream(self.shared_experts_stream)
|
||||
|
||||
final_hidden_states = (
|
||||
return self._maybe_combine(
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
def combine_output(states: torch.Tensor) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine:
|
||||
states = get_ep_group().combine(
|
||||
states, self.moe_config.is_sequence_parallel
|
||||
)
|
||||
|
||||
if self.moe_config.pcp_size > 1:
|
||||
states = get_pcp_group().reduce_scatter(
|
||||
states,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
return states
|
||||
|
||||
if self.shared_experts is not None:
|
||||
return (
|
||||
final_hidden_states[0],
|
||||
combine_output(final_hidden_states[1]),
|
||||
)
|
||||
else:
|
||||
return combine_output(final_hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user