[Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Author: Tyler Michael Smith
Date: 2025-09-27 10:22:28 -04:00 (committed by GitHub)
Parent: f9df8b4ad7
Commit: a5354b3ed2
23 changed files with 540 additions and 375 deletions


@@ -28,6 +28,8 @@ class Cache:
 class All2AllManagerBase:
+    rank: int
+    world_size: int
 
     def __init__(self, cpu_group):
         self.cpu_group = cpu_group
@@ -40,6 +42,7 @@ class All2AllManagerBase:
         # all2all lives in ep group, which is merged from dp and tp group
         self.dp_group = get_dp_group()
+        self.tp_group = get_tp_group()
         # no self.ep_group since self.ep_group is still in construction
         # when we create this object
         self.dp_rank = self.dp_group.rank_in_group
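
Aside, not part of the diff: the comment above says the all2all manager lives in the EP group, which is merged from the DP and TP groups. A minimal sketch of what that merge implies for rank bookkeeping, assuming a DP-major / TP-minor ordering (the function name and formula are illustrative; the authoritative mapping lives in vllm.distributed.parallel_state):

def ep_layout(dp_rank: int, dp_size: int,
              tp_rank: int, tp_size: int) -> tuple[int, int]:
    # EP spans every (DP, TP) rank pair.
    ep_world_size = dp_size * tp_size
    # Assumed ordering: DP-major, TP-minor.
    ep_rank = dp_rank * tp_size + tp_rank
    return ep_rank, ep_world_size
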
@@ -60,17 +63,21 @@ class All2AllManagerBase:
         # and reuse it for the same config.
         raise NotImplementedError
 
+    def dispatch(self,
+                 hidden_states: torch.Tensor,
+                 router_logits: torch.Tensor,
+                 is_sequence_parallel: bool = False):
+        raise NotImplementedError
+
     def set_num_sms(self, num_sms: int):
         pass
 
     def max_sms_used(self) -> Optional[int]:
         return None  # None means it could use the whole GPU
 
-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
-        raise NotImplementedError
-
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False):
         raise NotImplementedError
 
     def destroy(self):
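
For orientation, a hypothetical minimal subclass (not part of this PR) that satisfies the updated dispatch/combine contract. The class name and the pass-through behaviour are assumptions, and the reading of is_sequence_parallel as "tokens already sharded across the TP group" is inferred from the PR title; real backends perform the actual all-to-all here.

import torch

from vllm.distributed.device_communicators.base_device_communicator import (
    All2AllManagerBase)  # assumed import path


class PassthroughAll2AllManager(All2AllManagerBase):
    """Hypothetical example: dispatch/combine that simply forward tensors."""

    def dispatch(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
                 is_sequence_parallel: bool = False):
        # A real backend would all-to-all tokens to their experts here; the
        # flag tells it whether hidden_states arrives sequence-parallel
        # (sharded across TP ranks) rather than replicated.
        return hidden_states, router_logits

    def combine(self,
                hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False):
        # Inverse of dispatch: return expert output in the caller's layout.
        return hidden_states
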
@@ -267,15 +274,20 @@ class DeviceCommunicatorBase:
             module.quant_method.init_prepare_finalize(module)
 
     def dispatch(
-            self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Dispatch the hidden states and router logits to the appropriate device.
         This is a no-op in the base class.
         """
         return hidden_states, router_logits
 
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         """
         Combine the hidden states and router logits from the appropriate device.
         This is a no-op in the base class.
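
DeviceCommunicatorBase.dispatch and combine remain no-ops in the base class but now accept is_sequence_parallel. A hypothetical call-site sketch, not code from this PR; the function name and the experts callable are assumptions:

import torch

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)  # assumed import path


def moe_forward(communicator: DeviceCommunicatorBase,
                experts,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:
    # Route tokens to experts; the flag propagates the TP-attention +
    # EP-MoE (sequence-parallel) layout to the communicator backend.
    hidden_states, router_logits = communicator.dispatch(
        hidden_states, router_logits,
        is_sequence_parallel=is_sequence_parallel)
    expert_output = experts(hidden_states, router_logits)
    # Gather expert outputs back into the caller's layout.
    return communicator.combine(expert_output,
                                is_sequence_parallel=is_sequence_parallel)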