[MoE Refactor] Integrate Naive Prepare Finalize into MK (#32567)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: amirkl94 <203507526+amirkl94@users.noreply.github.com>
@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
         return buffer

-    def dispatch(
+    def dispatch_router_logits(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -84,6 +84,34 @@ class NaiveAll2AllManager(All2AllManagerBase):
         return hidden_states, router_logits

+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        is_sequence_parallel: bool = False,
+        extra_tensors: list[torch.Tensor] | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if extra_tensors is not None:
+            raise NotImplementedError(
+                "extra_tensors is not supported for NaiveAll2AllManager"
+            )
+        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
+
+        hidden_states = self.naive_multicast(
+            hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
+        )
+        topk_weights = self.naive_multicast(
+            topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
+        )
+        topk_ids = self.naive_multicast(
+            topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
+        )
+        return hidden_states, topk_weights, topk_ids
+
     def combine(
         self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
     ) -> torch.Tensor:
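Note: the naive path above leans on `naive_multicast`, which stitches every rank's chunk into one contiguous buffer. Below is a minimal stand-alone sketch of those semantics; the helper name, its signature, and the broadcast-per-rank strategy are assumptions for illustration, not vLLM's exact implementation.

# Illustrative sketch only (assumed helper, not vLLM's code): each rank
# writes its own chunk into a shared-layout buffer, then broadcasts it.
import torch
import torch.distributed as dist

def naive_multicast_sketch(
    x: torch.Tensor,          # local chunk, shape [num_local_tokens, ...]
    cu_tokens: torch.Tensor,  # CPU cumsum of per-rank token counts
    group: dist.ProcessGroup,
) -> torch.Tensor:
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    # Allocate room for every rank's tokens, then fill in the local slice.
    buffer = x.new_empty((int(cu_tokens[-1]), *x.shape[1:]))
    start = 0 if rank == 0 else int(cu_tokens[rank - 1])
    buffer[start : int(cu_tokens[rank])].copy_(x)
    # Every rank broadcasts its slice so all ranks end with the full buffer.
    for r in range(world_size):
        s = 0 if r == 0 else int(cu_tokens[r - 1])
        e = int(cu_tokens[r])
        dist.broadcast(buffer[s:e], src=dist.get_global_rank(group, r), group=group)
    return buffer

Dispatching topk_weights and topk_ids through the same multicast keeps all three tensors aligned row-for-row, which is what lets the modular kernel index them together after dispatch.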
@@ -114,7 +142,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
     def __init__(self, cpu_group):
         super().__init__(cpu_group)

-    def dispatch(
+    def dispatch_router_logits(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -148,6 +176,46 @@ class AgRsAll2AllManager(All2AllManagerBase):
-        return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
+        return gathered_tensors[0], gathered_tensors[1]
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        is_sequence_parallel: bool = False,
+        extra_tensors: list[torch.Tensor] | None = None,
+    ) -> (
+        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+    ):
+        """
+        Gather hidden_states, topk_weights and topk_ids (plus any
+        extra_tensors) from all dp ranks.
+        """
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None
+        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
+        assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
+
+        tensors_to_gather = [hidden_states, topk_weights, topk_ids]
+        if extra_tensors is not None:
+            tensors_to_gather.extend(extra_tensors)
+
+        gathered_tensors = dist_group.all_gatherv(
+            tensors_to_gather,
+            dim=0,
+            sizes=sizes,
+        )
+
+        hidden_states = gathered_tensors[0]
+        topk_weights = gathered_tensors[1]
+        topk_ids = gathered_tensors[2]
+
+        if extra_tensors is None:
+            return hidden_states, topk_weights, topk_ids
+
+        return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]
+
     def combine(
         self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
     ) -> torch.Tensor:
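Note: `all_gatherv` above is a variable-size all-gather, concatenating per-rank shards of different lengths along dim 0. Below is one way to realize those semantics with stock torch.distributed calls; it is a sketch of the contract, not how vLLM's GroupCoordinator necessarily implements it.

# Illustrative sketch only (assumed function, not vLLM's implementation):
# a variable-size all-gather built from one broadcast per rank.
import torch
import torch.distributed as dist

def all_gatherv_sketch(
    tensors: list[torch.Tensor],  # local shards, each sized sizes[rank] on dim 0
    sizes: list[int],             # per-rank token counts, known on every rank
    group: dist.ProcessGroup,
) -> list[torch.Tensor]:
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    gathered: list[torch.Tensor] = []
    for t in tensors:
        chunks = []
        for r in range(world_size):
            # Reuse the local shard on its owner; allocate a receive buffer elsewhere.
            chunk = t.contiguous() if r == rank else t.new_empty((sizes[r], *t.shape[1:]))
            dist.broadcast(chunk, src=dist.get_global_rank(group, r), group=group)
            chunks.append(chunk)
        gathered.append(torch.cat(chunks, dim=0))
    return gathered

Gathering hidden_states, topk_weights, topk_ids and extra_tensors in a single all_gatherv call keeps them on the same communication schedule, so each gathered row stays consistent across all of the tensors.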
@@ -216,7 +284,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
             pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
         )

-    def dispatch(
+    def dispatch_router_logits(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -225,6 +293,19 @@ class PPLXAll2AllManager(All2AllManagerBase):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError

+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        is_sequence_parallel: bool = False,
+        extra_tensors: list[torch.Tensor] | None = None,
+    ) -> (
+        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+    ):
+        raise NotImplementedError
+
     def combine(
         self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
     ) -> torch.Tensor:
@@ -264,7 +345,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     def get_handle(self, kwargs):
         raise NotImplementedError

-    def dispatch(
+    def dispatch_router_logits(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -273,6 +354,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError

+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        is_sequence_parallel: bool = False,
+        extra_tensors: list[torch.Tensor] | None = None,
+    ) -> (
+        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+    ):
+        raise NotImplementedError
+
     def combine(
         self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
     ) -> torch.Tensor:
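Note: PPLX and DeepEP keep raising NotImplementedError because they move tokens through their own fused kernels rather than this manager-level path. For the naive and AgRs managers, the new token-level entry point would be driven roughly as follows; variable names and shapes here are assumptions for illustration.

# Hypothetical caller-side sketch; `manager` is a Naive/AgRs all2all manager.
hidden_states, topk_weights, topk_ids = manager.dispatch(
    hidden_states,   # [num_local_tokens, hidden_dim]
    topk_weights,    # [num_local_tokens, top_k]
    topk_ids,        # [num_local_tokens, top_k]
    is_sequence_parallel=False,
)
# ... run the experts on the gathered tokens ...
hidden_states = manager.combine(hidden_states, is_sequence_parallel=False)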