[Model Bash][DeepSeekR1] Remove Shared Expert Clone (#34344)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2026-02-19 10:56:14 -05:00
committed by GitHub
parent ee1d25f199
commit 4685a630a2
2 changed files with 11 additions and 13 deletions

View File

@@ -240,24 +240,22 @@ class DefaultMoERunner(MoERunner):
)
)
hidden_states_clone: torch.Tensor | None = None
shared_experts_input: torch.Tensor | None = None
if use_shared_experts_stream:
assert self.shared_experts_stream is not None
assert self.moe_config.disable_inplace
shared_experts_input = (
shared_input if shared_input is not None else hidden_states
)
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
hidden_states_clone = shared_experts_input.clone()
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# Record that the shared_experts_input will be used in the
# shared_experts_stream to avoid a gc issue from
# deallocation. For more details:
# https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We don't need shared_output.record_stream(current_stream())
# because we sync the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)
shared_experts_input.record_stream(self.shared_experts_stream)
# Mark sync start point for the separate shared experts
# stream here since we want to run in parallel with the
@@ -265,7 +263,7 @@ class DefaultMoERunner(MoERunner):
assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(current_stream())
return use_shared_experts_stream, hidden_states_clone
return use_shared_experts_stream, shared_experts_input
def ensure_dp_chunking_init(self):
if not self.use_dp_chunking or self.batched_hidden_states is not None:
@@ -584,7 +582,7 @@ class DefaultMoERunner(MoERunner):
use_chunked_impl = self.use_dp_chunking
use_shared_experts_stream, hidden_states_clone = (
use_shared_experts_stream, shared_experts_input = (
self._maybe_setup_shared_experts_stream(
hidden_states,
shared_input,
@@ -726,7 +724,7 @@ class DefaultMoERunner(MoERunner):
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(hidden_states_clone)
shared_output = self.shared_experts(shared_experts_input)
current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (

View File

@@ -175,7 +175,7 @@ class MiniCPMMoE(nn.Module):
)
final_hidden_states = fused_experts(
hidden_states, self.ws, self.w2s, topk_weights, topk_ids, inplace=True
hidden_states, self.ws, self.w2s, topk_weights, topk_ids, inplace=False
)
if self.tp_size > 1: