[PERF] Allreduce fusion. Support torch native matching. Tuning of the thresholds (#24248)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Author: Ilya Markov
Date: 2025-11-11 00:33:11 +01:00
Committed by: GitHub
Commit: d17ecc6b19 (parent 021143561f)
6 changed files with 1284 additions and 83 deletions

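This change reworks where the MoE output all-reduce is issued so that the compiler-side allreduce fusion pass can match the collective natively in the torch graph, and it tunes the thresholds that decide when the fused path is used. For context, below is a hedged sketch of enabling such a fusion pass through the compilation config; the pass-config flag names are assumptions and should be verified against PassConfig in your vLLM version.

# Hedged sketch: enabling the allreduce fusion pass via the compilation config.
# The flag names below are assumptions; check PassConfig in
# vllm/config/compilation.py for the exact fields in your vLLM version.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example model
    tensor_parallel_size=2,  # the fusion only applies when a TP all-reduce exists
    compilation_config={
        "pass_config": {
            "enable_fi_allreduce_fusion": True,  # assumed flag name
            # "fi_allreduce_fusion_max_token_num": 16384,  # assumed threshold knob
        },
    },
)

outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)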

@@ -2356,6 +2356,16 @@ class FusedMoE(CustomOp):
                 value=0.0,
             )
 
+        def reduce_output(states: torch.Tensor) -> torch.Tensor:
+            if (
+                not self.is_sequence_parallel
+                and not self.use_dp_chunking
+                and self.reduce_results
+                and (self.tp_size > 1 or self.ep_size > 1)
+            ):
+                states = self.maybe_all_reduce_tensor_model_parallel(states)
+            return states
+
         if self.shared_experts is None:
             if current_platform.is_tpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
@@ -2366,7 +2376,14 @@ class FusedMoE(CustomOp):
             fused_output = torch.ops.vllm.moe_forward(
                 hidden_states, router_logits, self.layer_name
             )
-            return fused_output[..., :og_hidden_states]
+            if self.zero_expert_num is not None and self.zero_expert_num > 0:
+                assert isinstance(fused_output, tuple)
+                fused_output, zero_expert_result = fused_output
+                return (reduce_output(fused_output) + zero_expert_result)[
+                    ..., :og_hidden_states
+                ]
+            else:
+                return reduce_output(fused_output)[..., :og_hidden_states]
         else:
             if current_platform.is_tpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
@@ -2379,8 +2396,8 @@ class FusedMoE(CustomOp):
                 hidden_states, router_logits, self.layer_name
             )
             return (
-                shared_output[..., :og_hidden_states],
-                fused_output[..., :og_hidden_states],
+                reduce_output(shared_output)[..., :og_hidden_states],
+                reduce_output(fused_output)[..., :og_hidden_states],
             )
 
     def forward_cuda(
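The hunks above move the all-reduce out of forward_impl and into the traced forward wrapper, applied after torch.ops.vllm.moe_forward returns. The toy example below (assumptions only, not vLLM code) illustrates why that placement matters: ops executed inside a registered custom op collapse into one opaque graph node, while ops called in the Python wrapper around it remain individual nodes that a fusion pass can match and rewrite.

# Toy example (not vLLM code): graph visibility of ops inside vs. outside a
# custom op.
import torch
from torch.library import custom_op


@custom_op("demo::inner", mutates_args=())
def inner(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for forward_impl: anything done here (e.g. a collective) is
    # hidden behind a single graph node.
    return x * 2


@inner.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


def wrapper(x: torch.Tensor) -> torch.Tensor:
    y = torch.ops.demo.inner(x)
    # Stand-in for reduce_output: this op stays visible as its own graph node.
    return torch.relu(y)


def show_graph(gm: torch.fx.GraphModule, example_inputs):
    print(gm.graph)  # demo.inner appears as one node, relu as a separate node
    return gm


compiled = torch.compile(wrapper, backend=show_graph)
compiled(torch.randn(4))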
@@ -2667,31 +2684,21 @@ class FusedMoE(CustomOp):
             assert isinstance(final_hidden_states, tuple)
             final_hidden_states, zero_expert_result = final_hidden_states
 
-        def reduce_output(
-            states: torch.Tensor, do_combine: bool = True
-        ) -> torch.Tensor:
-            if do_naive_dispatch_combine and do_combine:
+        def combine_output(states: torch.Tensor) -> torch.Tensor:
+            if do_naive_dispatch_combine:
                 states = get_ep_group().combine(states, self.is_sequence_parallel)
-
-            if (
-                not self.is_sequence_parallel
-                and self.reduce_results
-                and (self.tp_size > 1 or self.ep_size > 1)
-            ):
-                states = self.maybe_all_reduce_tensor_model_parallel(states)
-
             return states
 
         if self.shared_experts is not None:
             return (
-                reduce_output(final_hidden_states[0], do_combine=False),
-                reduce_output(final_hidden_states[1]),
+                final_hidden_states[0],
+                combine_output(final_hidden_states[1]),
             )
         elif self.zero_expert_num is not None and self.zero_expert_num > 0:
             assert isinstance(final_hidden_states, torch.Tensor)
-            return reduce_output(final_hidden_states) + zero_expert_result
+            return (combine_output(final_hidden_states), zero_expert_result)
         else:
-            return reduce_output(final_hidden_states)
+            return combine_output(final_hidden_states)
 
     @classmethod
     def make_expert_params_mapping(
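With this split, forward_impl's combine_output handles only the expert-parallel dispatch/combine, and the zero-expert contribution is returned as a tuple rather than added in place; the addition now happens in forward, after reduce_output has run the all-reduce. A minimal schematic of the resulting flow, using stand-in functions instead of the real collectives:

# Schematic model (stand-ins, not vLLM code) of the restructured call flow.
import torch


def combine_output(states: torch.Tensor) -> torch.Tensor:
    # forward_impl: only the EP naive dispatch/combine happens here now.
    return states  # stand-in for get_ep_group().combine(...)


def reduce_output(states: torch.Tensor) -> torch.Tensor:
    # forward: TP/EP all-reduce, emitted in the traced graph where it can be fused.
    return states  # stand-in for maybe_all_reduce_tensor_model_parallel(...)


def forward_impl(states: torch.Tensor, zero_expert_result: torch.Tensor):
    # Returns the un-reduced states together with the zero-expert contribution.
    return combine_output(states), zero_expert_result


def forward(states: torch.Tensor, zero_expert_result: torch.Tensor) -> torch.Tensor:
    fused_output, zero_expert_result = forward_impl(states, zero_expert_result)
    # The zero-expert term is added after the all-reduce, as in the hunk above.
    return reduce_output(fused_output) + zero_expert_result


print(forward(torch.ones(2, 4), torch.zeros(2, 4)).sum())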