[Model Bash][DeepSeekR1] Remove Shared Expert Clone (#34344)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
@@ -240,24 +240,22 @@ class DefaultMoERunner(MoERunner):
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states_clone: torch.Tensor | None = None
|
||||
shared_experts_input: torch.Tensor | None = None
|
||||
if use_shared_experts_stream:
|
||||
assert self.shared_experts_stream is not None
|
||||
assert self.moe_config.disable_inplace
|
||||
|
||||
shared_experts_input = (
|
||||
shared_input if shared_input is not None else hidden_states
|
||||
)
|
||||
|
||||
# Clone BEFORE switching streams to avoid race condition
|
||||
# where routed_expert kernel may mutate hidden_states.
|
||||
hidden_states_clone = shared_experts_input.clone()
|
||||
|
||||
# Record that the clone will be used by shared_experts_stream
|
||||
# to avoid gc issue from deallocation of hidden_states_clone
|
||||
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
|
||||
# Record that the shared_experts_input will be used in the
|
||||
# shared_experts_stream to to avoid gc issue from
|
||||
# deallocation. For more details:
|
||||
# https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
|
||||
# NOTE: We don't need shared_output.record_stream(current_stream())
|
||||
# because we synch the streams before using shared_output.
|
||||
hidden_states_clone.record_stream(self.shared_experts_stream)
|
||||
shared_experts_input.record_stream(self.shared_experts_stream)
|
||||
|
||||
# Mark sync start point for the separate shared experts
|
||||
# stream here since we want to run in parallel with the
|
||||
@@ -265,7 +263,7 @@ class DefaultMoERunner(MoERunner):
|
||||
assert self.shared_experts_stream is not None
|
||||
self.shared_experts_stream.wait_stream(current_stream())
|
||||
|
||||
return use_shared_experts_stream, hidden_states_clone
|
||||
return use_shared_experts_stream, shared_experts_input
|
||||
|
||||
def ensure_dp_chunking_init(self):
|
||||
if not self.use_dp_chunking or self.batched_hidden_states is not None:
|
||||
@@ -584,7 +582,7 @@ class DefaultMoERunner(MoERunner):
|
||||
|
||||
use_chunked_impl = self.use_dp_chunking
|
||||
|
||||
use_shared_experts_stream, hidden_states_clone = (
|
||||
use_shared_experts_stream, shared_experts_input = (
|
||||
self._maybe_setup_shared_experts_stream(
|
||||
hidden_states,
|
||||
shared_input,
|
||||
@@ -726,7 +724,7 @@ class DefaultMoERunner(MoERunner):
|
||||
with torch.cuda.stream(self.shared_experts_stream):
|
||||
# Note that hidden_states clone() is necessary here to avoid
|
||||
# conflict with the main stream
|
||||
shared_output = self.shared_experts(hidden_states_clone)
|
||||
shared_output = self.shared_experts(shared_experts_input)
|
||||
current_stream().wait_stream(self.shared_experts_stream)
|
||||
|
||||
final_hidden_states = (
|
||||
|
||||
@@ -175,7 +175,7 @@ class MiniCPMMoE(nn.Module):
|
||||
)
|
||||
|
||||
final_hidden_states = fused_experts(
|
||||
hidden_states, self.ws, self.w2s, topk_weights, topk_ids, inplace=True
|
||||
hidden_states, self.ws, self.w2s, topk_weights, topk_ids, inplace=False
|
||||
)
|
||||
|
||||
if self.tp_size > 1:
|
||||
|
||||
Reference in New Issue
Block a user