[MoE Refactor][13/N] Convert FI to Use PFNoEP (#31533)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Robert Shaw <robertgshaw2@gmail.com> Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -21,7 +21,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
|
||||
count_expert_num_tokens,
|
||||
disable_inplace,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.worker.ubatching import (
|
||||
dbo_enabled,
|
||||
@@ -682,14 +681,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
shared_experts_stream: torch.cuda.Stream | None = None,
|
||||
moe_parallel_config: FusedMoEParallelConfig | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
self.shared_experts = shared_experts
|
||||
self.shared_experts_stream = shared_experts_stream
|
||||
|
||||
# prefer an explicit FusedMoEParallelConfig when available (from
|
||||
# FusedMoE layers / tests).
|
||||
@@ -904,34 +901,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
expert_num_tokens_cpu=c_expert_num_tokens_cpu,
|
||||
)
|
||||
|
||||
def _maybe_setup_shared_experts_stream(
|
||||
self, hidden_states: torch.Tensor
|
||||
) -> tuple[bool, torch.Tensor | None]:
|
||||
# decide whether to run shared experts on a separate CUDA stream to
|
||||
# overlap with the main fused MoE kernel.
|
||||
use_shared_experts_stream = (
|
||||
self.shared_experts is not None
|
||||
and self.shared_experts_stream is not None
|
||||
and hidden_states.is_cuda
|
||||
and (
|
||||
hidden_states.shape[0]
|
||||
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states_clone: torch.Tensor | None = None
|
||||
if use_shared_experts_stream and self.shared_experts_stream is not None:
|
||||
# TODO: Optimize this (complicated)
|
||||
# Note: this clone adds overhead but is required
|
||||
# for correctness with multiple CUDA streams and CUDA graph capture.
|
||||
hidden_states_clone = hidden_states.clone()
|
||||
# record that the clone will be used by the separate stream so its
|
||||
# lifetime is correctly tracked.
|
||||
hidden_states_clone.record_stream(self.shared_experts_stream)
|
||||
self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
return use_shared_experts_stream, hidden_states_clone
|
||||
|
||||
def _prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1119,30 +1088,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
hidden_states_clone: torch.Tensor | None = None,
|
||||
use_shared_experts_stream: bool = False,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
The _finalize method is a wrapper around self.prepare_finalize.finalize
|
||||
that handles DBO, async and shared expert overlap.
|
||||
"""
|
||||
|
||||
def maybe_run_shared_experts() -> torch.Tensor | None:
|
||||
if self.shared_experts is None:
|
||||
return None
|
||||
|
||||
if (
|
||||
not use_shared_experts_stream
|
||||
or self.shared_experts_stream is not None
|
||||
and (not hidden_states.is_cuda or not torch.cuda.is_available())
|
||||
):
|
||||
# fall back to running on the current stream
|
||||
return self.shared_experts(hidden_states)
|
||||
|
||||
assert hidden_states_clone is not None
|
||||
# launch shared experts on the dedicated stream.
|
||||
with torch.cuda.stream(self.shared_experts_stream):
|
||||
return self.shared_experts(hidden_states_clone)
|
||||
shared_output: torch.Tensor | None = None
|
||||
|
||||
if not self.prepare_finalize.supports_async():
|
||||
assert not dbo_enabled()
|
||||
@@ -1155,7 +1106,8 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
shared_output = maybe_run_shared_experts()
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
else:
|
||||
finalize_ret = self.prepare_finalize.finalize_async(
|
||||
output,
|
||||
@@ -1165,8 +1117,8 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.finalize_weight_and_reduce_impl(),
|
||||
)
|
||||
|
||||
shared_output = maybe_run_shared_experts()
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
@@ -1189,28 +1141,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
|
||||
receiver()
|
||||
|
||||
self._wait_for_shared_experts_stream(hidden_states, use_shared_experts_stream)
|
||||
|
||||
if self.shared_experts is None:
|
||||
return output
|
||||
else:
|
||||
assert shared_output is not None
|
||||
return shared_output, output
|
||||
|
||||
def _wait_for_shared_experts_stream(
|
||||
self, hidden_states: torch.Tensor, use_shared_experts_stream: bool
|
||||
) -> None:
|
||||
# ensure that any work enqueued on the shared_experts_stream is
|
||||
# completed before the shared_output tensor is consumed
|
||||
if (
|
||||
self.shared_experts is not None
|
||||
and use_shared_experts_stream
|
||||
and self.shared_experts_stream is not None
|
||||
and hidden_states.is_cuda
|
||||
and current_platform.is_cuda()
|
||||
):
|
||||
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1257,10 +1193,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
else:
|
||||
output = torch.zeros_like(hidden_states)
|
||||
|
||||
use_shared_experts_stream, hidden_states_clone = (
|
||||
self._maybe_setup_shared_experts_stream(hidden_states)
|
||||
)
|
||||
|
||||
local_num_experts = w1.size(0)
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
@@ -1297,6 +1229,4 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
hidden_states_clone=hidden_states_clone,
|
||||
use_shared_experts_stream=use_shared_experts_stream,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user