[Bugfix] Fix GPT-OSS AR+NORM fusion (#28841)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
elvischenv
2025-11-25 15:59:40 +08:00
committed by GitHub
parent ef1f7030f0
commit 6330f9477d
4 changed files with 24 additions and 7 deletions

View File

@@ -1690,6 +1690,10 @@ class FusedMoE(CustomOp):
)
def reduce_output(states: torch.Tensor) -> torch.Tensor:
# Slice before all_reduce to enable possible fusion
if self.hidden_size != og_hidden_states:
states = states[..., :og_hidden_states]
if (
not self.is_sequence_parallel
and not self.use_dp_chunking
@@ -1712,11 +1716,12 @@ class FusedMoE(CustomOp):
if self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(fused_output, tuple)
fused_output, zero_expert_result = fused_output
return (reduce_output(fused_output) + zero_expert_result)[
..., :og_hidden_states
]
return (
reduce_output(fused_output)
+ zero_expert_result[..., :og_hidden_states]
)
else:
return reduce_output(fused_output)[..., :og_hidden_states]
return reduce_output(fused_output)
else:
if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
@@ -1729,8 +1734,8 @@ class FusedMoE(CustomOp):
hidden_states, router_logits, self.layer_name
)
return (
reduce_output(shared_output)[..., :og_hidden_states],
reduce_output(fused_output)[..., :og_hidden_states],
reduce_output(shared_output),
reduce_output(fused_output),
)
def forward_cuda(