[MoE][Kernel][Perf] Improve Shared Expert Stream Overlap (#28406)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
committed by
GitHub
parent
4ca5cd5740
commit
69d0e90313
@@ -48,7 +48,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.utils.torch_utils import current_stream, direct_register_custom_op
|
||||
from vllm.utils.torch_utils import (
|
||||
aux_stream,
|
||||
current_stream,
|
||||
direct_register_custom_op,
|
||||
)
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
@@ -331,7 +335,11 @@ class FusedMoE(CustomOp):
|
||||
logger.info_once("Disabling MoE shared_experts cuda stream")
|
||||
self.shared_experts_stream = None
|
||||
else:
|
||||
self.shared_experts_stream = torch.cuda.Stream()
|
||||
# TODO(rob): enable shared expert overlap with non-cuda.
|
||||
# aux_stream() returns None on non-cuda platforms.
|
||||
self.shared_experts_stream = aux_stream()
|
||||
if self.shared_experts_stream is not None:
|
||||
logger.info_once("Enabled separate cuda stream for MoE shared_experts")
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
@@ -1606,7 +1614,9 @@ class FusedMoE(CustomOp):
|
||||
if has_separate_shared_experts:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
|
||||
shared_output = self.shared_experts(staged_hidden_states)
|
||||
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
@@ -1684,13 +1694,34 @@ class FusedMoE(CustomOp):
|
||||
|
||||
use_chunked_impl = self.use_dp_chunking
|
||||
|
||||
if (
|
||||
use_shared_experts_stream = (
|
||||
has_separate_shared_experts
|
||||
and not use_chunked_impl
|
||||
and self.shared_experts_stream is not None
|
||||
):
|
||||
# Start the separate shared experts stream here since we want
|
||||
# to run in parallel with the router/gate (next op below)
|
||||
and (
|
||||
hidden_states.shape[0]
|
||||
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
|
||||
)
|
||||
)
|
||||
|
||||
if use_shared_experts_stream:
|
||||
assert self.shared_experts_stream is not None
|
||||
|
||||
# Clone BEFORE switching streams to avoid race condition
|
||||
# where routed_expert kernel may mutate hidden_states.
|
||||
hidden_states_clone = hidden_states.clone()
|
||||
|
||||
# Record that the clone will be used by shared_experts_stream
|
||||
# to avoid gc issue from deallocation of hidden_states_clone
|
||||
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
|
||||
# NOTE: We dont need shared_output.record_stream(current_stream())
|
||||
# because we synch the streams before using shared_output.
|
||||
hidden_states_clone.record_stream(self.shared_experts_stream)
|
||||
|
||||
# Mark sync start point for the separate shared experts
|
||||
# stream here since we want to run in parallel with the
|
||||
# router/gate (next op below)
|
||||
assert self.shared_experts_stream is not None
|
||||
self.shared_experts_stream.wait_stream(current_stream())
|
||||
|
||||
# If router/gate provided, then apply it here.
|
||||
@@ -1709,33 +1740,6 @@ class FusedMoE(CustomOp):
|
||||
self.quant_method, FusedMoEModularMethod
|
||||
)
|
||||
|
||||
# If there are shared experts but we are not using a modular kernel, the
|
||||
# shared experts must be called here
|
||||
if has_separate_shared_experts:
|
||||
assert self.shared_experts is not None
|
||||
|
||||
if self.shared_experts_stream is not None:
|
||||
# Clone BEFORE switching streams to avoid race condition
|
||||
# where routed_expert kernel may mutate hidden_states.
|
||||
hidden_states_clone = hidden_states.clone()
|
||||
self.shared_experts_stream.wait_stream(current_stream())
|
||||
|
||||
# Run shared experts in parallel on a separate stream
|
||||
with torch.cuda.stream(self.shared_experts_stream):
|
||||
shared_output = self.shared_experts(hidden_states_clone)
|
||||
|
||||
# Record that the clone will be used by shared_experts_stream
|
||||
# to avoid gc issue from deallocation of hidden_states_clone
|
||||
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
|
||||
# NOTE: we dont need shared_output.record_stream(current_stream())
|
||||
# because we synch the streams before using shared_output.
|
||||
hidden_states_clone.record_stream(self.shared_experts_stream)
|
||||
|
||||
else:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
else:
|
||||
shared_output = None
|
||||
|
||||
ctx = get_forward_context()
|
||||
sp_ctx = (
|
||||
ctx.dp_metadata.sp_local_sizes(self.sp_size)
|
||||
@@ -1776,12 +1780,21 @@ class FusedMoE(CustomOp):
|
||||
)
|
||||
|
||||
if has_separate_shared_experts:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
|
||||
# Wait for the parallel shared experts stream to finish here
|
||||
if self.shared_experts_stream is not None:
|
||||
if use_shared_experts_stream:
|
||||
# Run shared experts in parallel on a separate stream
|
||||
# NOTE: We start the separate stream here and mark the
|
||||
# sync end point immediately after it is done. This is
|
||||
# important to avoid excessive stream allocations by the cuda
|
||||
# graph replay later.
|
||||
with torch.cuda.stream(self.shared_experts_stream):
|
||||
# Note that hidden_states clone() is necessary here to avoid
|
||||
# conflict with the main stream
|
||||
shared_output = self.shared_experts(hidden_states_clone)
|
||||
current_stream().wait_stream(self.shared_experts_stream)
|
||||
else:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
|
||||
Reference in New Issue
Block a user