Signed-off-by: Huamin Li <3ericli@gmail.com>

@@ -1690,10 +1690,6 @@ class FusedMoE(CustomOp):
             )

         def reduce_output(states: torch.Tensor) -> torch.Tensor:
-            # Slice before all_reduce to enable possible fusion
-            if self.hidden_size != og_hidden_states:
-                states = states[..., :og_hidden_states]
-
             if (
                 not self.is_sequence_parallel
                 and not self.use_dp_chunking
@@ -1716,12 +1712,11 @@ class FusedMoE(CustomOp):
             if self.zero_expert_num is not None and self.zero_expert_num > 0:
                 assert isinstance(fused_output, tuple)
                 fused_output, zero_expert_result = fused_output
-                return (
-                    reduce_output(fused_output)
-                    + zero_expert_result[..., :og_hidden_states]
-                )
+                return (reduce_output(fused_output) + zero_expert_result)[
+                    ..., :og_hidden_states
+                ]
             else:
-                return reduce_output(fused_output)
+                return reduce_output(fused_output)[..., :og_hidden_states]
         else:
             if current_platform.is_tpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
@@ -1734,8 +1729,8 @@ class FusedMoE(CustomOp):
                 hidden_states, router_logits, self.layer_name
             )
             return (
-                reduce_output(shared_output),
-                reduce_output(fused_output),
+                reduce_output(shared_output)[..., :og_hidden_states],
+                reduce_output(fused_output)[..., :og_hidden_states],
             )

     def forward_cuda(
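
For context, here is a minimal standalone sketch (not part of this commit; the tensor names and sizes are made up, and a plain sum stands in for the all_reduce) of why moving the slice is numerically safe: truncating to og_hidden_states commutes with the elementwise sum performed by the reduction, so slicing the reduced result once at each return site matches the old behavior of slicing every partial output inside reduce_output.

import torch

og_hidden_states = 6    # original (unpadded) hidden size
padded_hidden = 8       # hidden size after padding, e.g. for alignment
partial_a = torch.randn(2, padded_hidden)  # partial MoE output from one rank
partial_b = torch.randn(2, padded_hidden)  # partial MoE output from another rank

# Old order: slice each partial result, then reduce.
old = partial_a[..., :og_hidden_states] + partial_b[..., :og_hidden_states]

# New order: reduce first, then slice the combined tensor once.
new = (partial_a + partial_b)[..., :og_hidden_states]

assert torch.allclose(old, new)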