Elastic Expert Parallel Initial Support (#20775)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
@@ -29,12 +29,15 @@ physical experts.
 import time
 from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Optional, Union
 
 import torch
-from torch.distributed import all_gather, all_reduce
+from torch.distributed import ProcessGroup, all_gather, all_reduce
 
 from vllm.config import ParallelConfig
-from vllm.distributed.parallel_state import get_ep_group, get_node_count
+from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
+                                             in_the_same_node_as)
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import MixtureOfExperts
@@ -172,6 +175,9 @@ class EplbState:
         model: MixtureOfExperts,
         device: torch.device,
         parallel_config: ParallelConfig,
+        global_expert_load: Optional[torch.Tensor] = None,
+        old_global_expert_indices: Optional[torch.Tensor] = None,
+        rank_mapping: Optional[dict[int, int]] = None,
     ) -> "EplbState":
         """
         Build the initial EPLB state.
@@ -185,8 +191,16 @@ class EplbState:
             physical_to_logical_map_list,
             device=device,
         )
+        # Assuming 8 GPUs per node, this supports up to
+        # (1023 + 1) / 8 = 128 nodes for now.
+        # TODO(rui): make this configurable
+        MAX_EXPERT_REDUNDANCY = 1023
+        assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
+            f"num_redundant_experts {model.num_redundant_experts} "
+            f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}")
+        max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1
         logical_to_physical_map = torch.full(
-            (model.num_logical_experts, model.num_redundant_experts + 1),
+            (model.num_logical_experts, max_slots_per_logical_expert),
             -1,
             device=device,
         )
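The redundancy cap above also fixes the width of the logical-to-physical table at MAX_EXPERT_REDUNDANCY + 1 slots per logical expert. A quick sketch of the node arithmetic from the comment; GPUS_PER_NODE is an illustrative name for the comment's stated assumption, not part of the commit:

MAX_EXPERT_REDUNDANCY = 1023
GPUS_PER_NODE = 8  # assumption stated in the comment above
max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1  # 1024 replica slots
# With at most one replica of a logical expert per GPU, 1024 replica slots
# cover up to 1024 GPUs, i.e. 1024 / 8 = 128 eight-GPU nodes.
assert max_slots_per_logical_expert // GPUS_PER_NODE == 128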
@@ -235,11 +249,63 @@ class EplbState:
         expert_rearrangement_step = max(
             0, eplb_step_interval - eplb_step_interval // 4)
 
+        if global_expert_load is not None:
+            ep_group = get_ep_group().device_group
+            assert global_expert_load.shape == (model.num_moe_layers,
+                                                model.num_logical_experts)
+            assert global_expert_load.dtype == torch.int64
+
+            num_replicas = model.num_physical_experts
+            num_groups = model.num_expert_groups
+            num_nodes = get_node_count()
+            num_gpus = ep_group.size()
+
+            if num_gpus % num_nodes != 0:
+                num_nodes = 1
+                logger.warning_once(
+                    f"num_gpus % num_nodes != 0, "
+                    "not using hierarchical rearrangement algorithm.\n"
+                    f"{num_gpus=}, {num_nodes=}")
+
+            # Get new expert mappings
+            (
+                new_physical_to_logical_map,
+                new_logical_to_physical_map,
+                new_logical_replica_count,
+            ) = (rebalance_experts(
+                global_expert_load,
+                num_replicas,
+                num_groups,
+                num_nodes,
+                num_gpus,
+            ))
+
+            max_physical_slots = new_logical_to_physical_map.shape[-1]
+            assert max_physical_slots <= logical_to_physical_map.shape[-1]
+            new_logical_to_physical_map = torch.nn.functional.pad(
+                new_logical_to_physical_map,
+                (0, logical_to_physical_map.shape[-1] - max_physical_slots),
+                value=-1,
+            )
+            physical_to_logical_map = new_physical_to_logical_map.to(device)
+            logical_to_physical_map.copy_(new_logical_to_physical_map)
+            logical_replica_count.copy_(new_logical_replica_count)
+
         model.set_eplb_state(
             expert_load_pass,
             logical_to_physical_map,
             logical_replica_count,
         )
+        if global_expert_load is not None:
+            rearrange_expert_weights_inplace(
+                old_global_expert_indices,
+                new_physical_to_logical_map,
+                model.expert_weights,
+                ep_group,
+                False,
+                rank_mapping,
+            )
+            expert_rearrangement_step = 0
 
         return cls(
             physical_to_logical_map,
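For intuition about the two tables being rebuilt here: physical_to_logical_map stores, per physical slot, the logical expert it hosts, and logical_to_physical_map is its fixed-width inverse, padded with -1 just as the torch.nn.functional.pad call above pads the narrower map returned by rebalance_experts. A self-contained toy illustration, with made-up sizes:

import torch

# Toy setup: 4 physical slots, 3 logical experts, so expert 1 is replicated.
physical_to_logical = torch.tensor([0, 1, 2, 1])  # slot -> logical expert

# Invert into a fixed-width table padded with -1.
max_slots = 2
logical_to_physical = torch.full((3, max_slots), -1)
replica_count = torch.zeros(3, dtype=torch.long)
for slot, expert in enumerate(physical_to_logical.tolist()):
    logical_to_physical[expert, replica_count[expert]] = slot
    replica_count[expert] += 1

print(logical_to_physical)  # tensor([[ 0, -1], [ 1,  3], [ 2, -1]])
print(replica_count)        # tensor([1, 2, 1])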
@@ -337,7 +403,10 @@ class EplbState:
 
     def rearrange(self,
                   model: MixtureOfExperts,
-                  is_profile: bool = False) -> None:
+                  is_profile: bool = False,
+                  execute_shuffle: bool = True,
+                  global_expert_load: Optional[torch.Tensor] = None,
+                  rank_mapping: Optional[dict[int, int]] = None) -> None:
         """
         Rearrange the experts according to the current load.
         """
@@ -353,42 +422,79 @@ class EplbState:
         logger.info("Rearranging experts %s...",
                     "(profile)" if is_profile else "")
 
-        # This mapping is only used here, so we do not store it in the state
-        physical_expert_start = ep_rank * model.num_local_physical_experts
-        physical_expert_end = (physical_expert_start +
-                               model.num_local_physical_experts)
-        # (num_moe_layers, num_local_physical_experts)
-        local_physical_to_logical_map = self.physical_to_logical_map[
-            :,
-            physical_expert_start:physical_expert_end,
-        ]
+        if global_expert_load is None:
+            # This mapping is only used here, so we do not store it in the state
+            physical_expert_start = ep_rank * model.num_local_physical_experts
+            physical_expert_end = (physical_expert_start +
+                                   model.num_local_physical_experts)
+            # (num_moe_layers, num_local_physical_experts)
+            local_physical_to_logical_map = self.physical_to_logical_map[
+                :,
+                physical_expert_start:physical_expert_end,
+            ]
 
-        # Map the local physical expert load to global logical experts
-        logical_expert_load_window = torch.zeros(
-            self.expert_load_window_size,
-            model.num_moe_layers,
-            model.num_logical_experts,
-            dtype=self.expert_load_window.dtype,
-            device=self.expert_load_window.device,
-        )
-        logical_expert_load_window.scatter_add_(
-            dim=-1,
-            index=local_physical_to_logical_map.unsqueeze(0).expand_as(
-                self.expert_load_window).long(),
-            src=self.expert_load_window,
-        )
+            # Map the local physical expert load to global logical experts
+            logical_expert_load_window = torch.zeros(
+                self.expert_load_window_size,
+                model.num_moe_layers,
+                model.num_logical_experts,
+                dtype=self.expert_load_window.dtype,
+                device=self.expert_load_window.device,
+            )
+            logical_expert_load_window.scatter_add_(
+                dim=-1,
+                index=local_physical_to_logical_map.unsqueeze(0).expand_as(
+                    self.expert_load_window).long(),
+                src=self.expert_load_window,
+            )
 
-        # Perform all-reduce to get the expert load across all ranks
-        global_expert_load_window = logical_expert_load_window.sum(dim=0)
-        all_reduce(global_expert_load_window, group=ep_group)
+            if not execute_shuffle:
+                metadata = torch.tensor(
+                    [
+                        model.num_moe_layers, model.num_logical_experts,
+                        self.physical_to_logical_map.shape[1]
+                    ],
+                    dtype=torch.int32,
+                    device="cpu",
+                )
+                torch.distributed.broadcast(metadata,
+                                            group=get_ep_group().cpu_group,
+                                            group_src=0)
+
+            # Perform all-reduce to get the expert load across all ranks
+            global_expert_load_window = logical_expert_load_window.sum(dim=0)
+            all_reduce(global_expert_load_window, group=ep_group)
+
+            if not execute_shuffle:
+                # (num_moe_layers, old_num_physical_experts)
+                old_global_expert_indices = self.physical_to_logical_map
+                torch.distributed.broadcast(old_global_expert_indices,
+                                            group=ep_group,
+                                            group_src=0)
+                return global_expert_load_window
+        else:
+            assert execute_shuffle
+            global_expert_load_window = global_expert_load
 
         # TODO(bowen): Treat differently for prefill and decode nodes
         num_replicas = model.num_physical_experts
         num_groups = model.num_expert_groups
-        num_nodes = get_node_count()
-        num_gpus = ep_group.size()
+        if rank_mapping is not None and len(rank_mapping) == ep_group.size():
+            # NOTE(yongji): scale down, we need to rebalance the experts on
+            # remaining GPUs, transfer the experts while we haven't shutdown
+            # the GPUs to be released.
+            cpu_group = get_ep_group().cpu_group
+            num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
+            num_gpus = sum(new_rank != -1
+                           for new_rank in rank_mapping.values())
+            num_replicas = num_replicas // ep_group.size(
+            ) * num_gpus  # handle num replicas change
+        else:
+            num_nodes = get_node_count()
+            num_gpus = ep_group.size()
 
         if num_gpus % num_nodes != 0:
             self.num_nodes = 1
             logger.warning_once(
                 f"num_gpus % num_nodes != 0, "
                 "not using hierarchical rearrangement algorithm.\n"
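The scatter_add_ in this hunk folds per-physical-slot load counters into per-logical-expert totals, so all replicas of one expert are weighed together when rebalancing. A minimal sketch of that reduction with toy shapes (window dimension omitted):

import torch

# Toy: one layer, 4 local physical slots whose logical owners are [0, 1, 2, 1].
local_physical_to_logical = torch.tensor([[0, 1, 2, 1]])
physical_load = torch.tensor([[10, 20, 30, 40]])

logical_load = torch.zeros((1, 3), dtype=physical_load.dtype)
logical_load.scatter_add_(dim=-1,
                          index=local_physical_to_logical,
                          src=physical_load)
print(logical_load)  # tensor([[10, 60, 30]]) -- expert 1 sums slots 1 and 3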
@@ -414,10 +520,24 @@ class EplbState:
             model.expert_weights,
             ep_group,
             is_profile,
+            rank_mapping,
         )
 
         if not is_profile:
-            self.physical_to_logical_map.copy_(new_physical_to_logical_map)
+            if self.physical_to_logical_map.shape[
+                    1] != new_physical_to_logical_map.shape[1]:
+                self.physical_to_logical_map = new_physical_to_logical_map.to(
+                    self.physical_to_logical_map.device)
+            else:
+                self.physical_to_logical_map.copy_(new_physical_to_logical_map)
+            max_physical_slots = new_logical_to_physical_map.shape[-1]
+            assert max_physical_slots <= self.logical_to_physical_map.shape[-1]
+            new_logical_to_physical_map = torch.nn.functional.pad(
+                new_logical_to_physical_map,
+                (0,
+                 self.logical_to_physical_map.shape[-1] - max_physical_slots),
+                value=-1,
+            )
             self.logical_to_physical_map.copy_(new_logical_to_physical_map)
             self.logical_replica_count.copy_(new_logical_replica_count)
@@ -430,3 +550,69 @@ class EplbState:
             " (profile) " if is_profile else " ",
             time_end - time_start,
         )
+
+    @staticmethod
+    def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Receive the expert load and old placement from the master rank.
+        """
+        ep_group = get_ep_group()
+        metadata = torch.empty(3, dtype=torch.int32, device="cpu")
+        torch.distributed.broadcast(metadata,
+                                    group=ep_group.cpu_group,
+                                    group_src=0)
+        num_moe_layers, num_logical_experts, num_old_physical_experts = (
+            metadata.tolist())
+        global_expert_load = torch.zeros(
+            (num_moe_layers, num_logical_experts),
+            dtype=torch.int64,
+            device=ep_group.device,
+        )
+        all_reduce(global_expert_load, group=ep_group.device_group)
+        old_global_expert_indices = torch.empty(
+            (num_moe_layers, num_old_physical_experts),
+            dtype=torch.int64,
+            device=ep_group.device,
+        )
+        torch.distributed.broadcast(old_global_expert_indices,
+                                    group=ep_group.device_group,
+                                    group_src=0)
+
+        return global_expert_load, old_global_expert_indices
+
+
+def _node_count_with_rank_mapping(
+    pg: Union[ProcessGroup, StatelessProcessGroup],
+    rank_mapping: dict[int, int],
+) -> int:
+    if isinstance(pg, ProcessGroup):
+        world_size = torch.distributed.get_world_size(group=pg)
+    else:
+        world_size = pg.world_size
+
+    if world_size == 1:
+        return 1
+
+    # Build node assignment map
+    node_assignment = [0] * world_size  # rank -> node_id
+    next_node_id = 0
+
+    for current_rank in range(world_size):
+        if node_assignment[current_rank] != 0:
+            continue  # Already assigned to a node
+
+        assert current_rank in rank_mapping
+        if rank_mapping[current_rank] == -1:
+            continue  # Pending shutdown
+
+        # Assign current rank to a new node
+        next_node_id += 1
+        node_assignment[current_rank] = next_node_id
+
+        # Find all ranks on the same node as current_rank
+        same_node_flags = in_the_same_node_as(pg, current_rank)
+        for other_rank, is_same_node in enumerate(same_node_flags):
+            if is_same_node and node_assignment[other_rank] == 0:
+                node_assignment[other_rank] = next_node_id
+
+    return next_node_id
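_node_count_with_rank_mapping counts only the nodes that still host at least one surviving rank. A hedged offline sketch of the same counting logic; fake_same_node is a hypothetical stand-in for the in_the_same_node_as collective:

# Hypothetical topology: ranks 0-3 on node A, ranks 4-7 on node B.
def fake_same_node(rank: int) -> list[bool]:
    node = rank // 4
    return [r // 4 == node for r in range(8)]

def count_nodes(world_size: int, rank_mapping: dict[int, int]) -> int:
    node_assignment = [0] * world_size
    next_node_id = 0
    for current_rank in range(world_size):
        if node_assignment[current_rank] != 0:
            continue  # already assigned to a counted node
        if rank_mapping[current_rank] == -1:
            continue  # pending shutdown; do not count its node yet
        next_node_id += 1
        for other_rank, same in enumerate(fake_same_node(current_rank)):
            if same and node_assignment[other_rank] == 0:
                node_assignment[other_rank] = next_node_id
    return next_node_id

# Scaling down from 8 to 4 ranks: node B's ranks all map to -1, so only
# node A is counted.
mapping = {0: 0, 1: 1, 2: 2, 3: 3, 4: -1, 5: -1, 6: -1, 7: -1}
assert count_nodes(8, mapping) == 1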
@@ -8,6 +8,7 @@ This involves the exchange of expert weights between GPUs.
 
 from collections.abc import Iterable, MutableSequence, Sequence
 from functools import partial
+from typing import Optional
 
 import torch
 from torch.distributed import (P2POp, ProcessGroup, all_gather,
@@ -127,6 +128,8 @@ def shuffle_layer(
         dst_global = local2global(dst)
         if is_received_locally[dst]:
             continue
+        if old_indices[src_global] == -1 or new_indices[dst_global] == -1:
+            continue
         if old_indices[src_global] == new_indices[dst_global]:
             is_received_locally[dst] = True
             for weight, buffer in zip(expert_weights,
@@ -139,6 +142,8 @@ def shuffle_layer(
     experts_send_loc: dict[int, int] = {}
     for src in range(num_local_experts):
         expert = old_indices[local2global(src)]
+        if expert == -1:
+            continue
         if expert in experts_send_loc:
             continue
         experts_send_loc[expert] = src
@@ -181,6 +186,8 @@ def shuffle_layer(
         if is_received_locally[dst]:
             continue
         expert = new_indices[local2global(dst)]
+        if expert == -1:
+            continue
         if expert in experts_recv_loc:
             continue
         experts_recv_loc[expert] = dst
@@ -227,6 +234,8 @@ def shuffle_layer(
                 weight[dst].copy_(buffer[dst])
         else:
             expert = new_indices[local2global(dst)]
+            if expert == -1:
+                continue
             src = experts_recv_loc[expert]
             for weight, buffer in zip(expert_weights, expert_weights_buffer):
                 weight[dst].copy_(buffer[src])
@@ -238,6 +247,7 @@ def rearrange_expert_weights_inplace(
     expert_weights: Sequence[Iterable[torch.Tensor]],
     ep_group: ProcessGroup,
     is_profile: bool = False,
+    rank_mapping: Optional[dict[int, int]] = None,
 ) -> None:
     """
     Rearranges the expert weights in place according to the new expert indices.
@@ -256,7 +266,28 @@ def rearrange_expert_weights_inplace(
         is_profile (bool): If `True`, do not perform any actual weight copy.
             This is used during profile run, where we only perform dummy
             communications to reserve enough memory for the buffers.
+        rank_mapping: A dictionary mapping old rank to new rank.
     """
+    if rank_mapping is not None:
+        if len(rank_mapping) == ep_group.size():
+            # scale down
+            new_global_expert_indices = \
+                _map_new_expert_indices_with_rank_mapping(
+                    new_global_expert_indices,
+                    rank_mapping,
+                )
+        else:
+            # scale up
+            old_global_expert_indices = \
+                _map_old_expert_indices_with_rank_mapping(
+                    old_global_expert_indices,
+                    rank_mapping,
+                    ep_group.size(),
+                )
+
+    assert old_global_expert_indices.shape[
+        1] == new_global_expert_indices.shape[1]
+
     num_moe_layers, num_physical_experts = old_global_expert_indices.shape
     assert len(expert_weights) == num_moe_layers
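Note the shape convention the branch above depends on: rank_mapping always has one entry per old rank. On scale down it therefore still covers the whole not-yet-shrunk group (len(rank_mapping) == ep_group.size(), with -1 marking ranks to be released), while on scale up it is smaller than the already-grown group. Illustrative values, not from the commit:

# Scale up from 2 to 4 ranks: the mapping lists only the 2 old ranks, and
# ep_group is already the new 4-rank group, so len(rank_mapping) != size.
scale_up = {0: 0, 1: 1}

# Scale down from 4 to 2 ranks: the mapping covers all 4 current ranks while
# the 4-rank group still exists; -1 marks ranks being released.
scale_down = {0: 0, 1: 1, 2: -1, 3: -1}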
@@ -304,4 +335,90 @@ def rearrange_expert_weights_inplace(
         )
+
+
+def _map_old_expert_indices_with_rank_mapping(
+    old_global_expert_indices: torch.Tensor,
+    rank_mapping: dict[int, int],
+    new_ep_size: int,
+) -> torch.Tensor:
+    """
+    Map the old global expert indices to the new global expert indices.
+
+    Args:
+        old_global_expert_indices:
+            Shape (num_layers, old_ep_size * num_local_physical_experts).
+        rank_mapping: Mapping from old rank to new rank.
+        new_ep_size: New expert parallelism size.
+
+    Returns:
+        Mapped expert indices with shape
+        (num_layers, new_ep_size * num_local_physical_experts).
+    """
+    num_layers, old_num_physical_experts = old_global_expert_indices.shape
+    assert rank_mapping, "Rank mapping is required"
+
+    # Get sizes from parameters and rank_mapping
+    old_ep_size = len(rank_mapping)
+    num_local_physical_experts = old_num_physical_experts // old_ep_size
+    new_num_physical_experts = new_ep_size * num_local_physical_experts
+
+    # Create mapped tensor with new shape, initialized to -1
+    mapped_expert_indices = torch.full(
+        (num_layers, new_num_physical_experts),
+        fill_value=-1,
+        dtype=old_global_expert_indices.dtype,
+        device=old_global_expert_indices.device,
+    )
+
+    # Handle rank mapping (scale up/down with rank changes)
+    for old_rank in range(old_ep_size):
+        new_rank = rank_mapping.get(old_rank)
+        if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size:
+            # This old rank exists in the new configuration
+            old_start_idx = old_rank * num_local_physical_experts
+            old_end_idx = (old_rank + 1) * num_local_physical_experts
+            new_start_idx = new_rank * num_local_physical_experts
+            new_end_idx = (new_rank + 1) * num_local_physical_experts
+
+            mapped_expert_indices[:, new_start_idx:new_end_idx] = \
+                old_global_expert_indices[:, old_start_idx:old_end_idx]
+        # If new_rank is None or >= new_ep_size, the experts remain -1
+        # (scale down case)
+
+    return mapped_expert_indices
+
+
+def _map_new_expert_indices_with_rank_mapping(
+    new_global_expert_indices: torch.Tensor,
+    rank_mapping: dict[int, int],
+) -> torch.Tensor:
+    num_layers, new_num_physical_experts = new_global_expert_indices.shape
+    assert rank_mapping, "Rank mapping is required"
+
+    # Get sizes from parameters and rank_mapping
+    old_ep_size = len(rank_mapping)
+    new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values())
+    num_local_physical_experts = new_num_physical_experts // new_ep_size
+    old_num_physical_experts = old_ep_size * num_local_physical_experts
+
+    mapped_expert_indices = torch.full(
+        (num_layers, old_num_physical_experts),
+        fill_value=-1,
+        dtype=new_global_expert_indices.dtype,
+        device=new_global_expert_indices.device,
+    )
+
+    for old_rank in range(old_ep_size):
+        new_rank = rank_mapping[old_rank]
+        if new_rank >= 0 and new_rank < new_ep_size:
+            old_start_idx = old_rank * num_local_physical_experts
+            old_end_idx = (old_rank + 1) * num_local_physical_experts
+            new_start_idx = new_rank * num_local_physical_experts
+            new_end_idx = (new_rank + 1) * num_local_physical_experts
+
+            mapped_expert_indices[:, old_start_idx:old_end_idx] = \
+                new_global_expert_indices[:, new_start_idx:new_end_idx]
+
+    return mapped_expert_indices
+
+
 __all__ = ["rearrange_expert_weights_inplace"]
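To make the slicing concrete, here is a hedged toy run of the scale-up mapper, assuming two local physical experts per rank and that the helper above is in scope; the scale-down mapper applies the same per-rank slices in the opposite direction:

import torch

# 1 layer, 2 old ranks, 2 local physical experts each -> 4 old slots.
old_indices = torch.tensor([[7, 8, 9, 10]])
# Old rank 0 stays at rank 0; old rank 1 moves to rank 2 of a 3-rank world.
mapped = _map_old_expert_indices_with_rank_mapping(
    old_indices, rank_mapping={0: 0, 1: 2}, new_ep_size=3)
print(mapped)  # tensor([[ 7,  8, -1, -1,  9, 10]]) -- new rank 1 starts empty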