Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
'''
"""
Expert parallelism load balancer (EPLB).
'''
"""

from .eplb_state import *
from .rebalance_algo import *

@@ -35,8 +35,11 @@ import torch
from torch.distributed import ProcessGroup, all_reduce

from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
in_the_same_node_as)
from vllm.distributed.parallel_state import (
get_ep_group,
get_node_count,
in_the_same_node_as,
)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
@@ -190,11 +193,10 @@ class EplbState:
"""
Build the initial EPLB state.
"""
physical_to_logical_map_list = (
cls.build_initial_global_physical_to_logical_map(
model.num_routed_experts,
model.num_redundant_experts,
))
physical_to_logical_map_list = cls.build_initial_global_physical_to_logical_map(
model.num_routed_experts,
model.num_redundant_experts,
)
physical_to_logical_map = torch.tensor(
physical_to_logical_map_list,
device=device,
@@ -205,7 +207,8 @@ class EplbState:
MAX_EXPERT_REDUNDANCY = 1023
assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
f"num_redundant_experts {model.num_redundant_experts} "
f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}")
f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}"
)
max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1
logical_to_physical_map = torch.full(
(model.num_logical_experts, max_slots_per_logical_expert),
@@ -213,31 +216,42 @@ class EplbState:
device=device,
)
logical_replica_count = torch.zeros(
(model.num_logical_experts, ),
(model.num_logical_experts,),
device=device,
dtype=torch.long,
)

for i in range(model.num_physical_experts):
logical_idx = physical_to_logical_map[i]
logical_to_physical_map[logical_idx,
logical_replica_count[logical_idx]] = i
logical_to_physical_map[logical_idx, logical_replica_count[logical_idx]] = i
logical_replica_count[logical_idx] += 1

# Duplicate initial mapping for all layers
physical_to_logical_map = physical_to_logical_map.unsqueeze(0).expand(
model.num_moe_layers,
-1,
).contiguous()
logical_to_physical_map = logical_to_physical_map.unsqueeze(0).expand(
model.num_moe_layers,
-1,
-1,
).contiguous()
logical_replica_count = logical_replica_count.unsqueeze(0).expand(
model.num_moe_layers,
-1,
).contiguous()
physical_to_logical_map = (
physical_to_logical_map.unsqueeze(0)
.expand(
model.num_moe_layers,
-1,
)
.contiguous()
)
logical_to_physical_map = (
logical_to_physical_map.unsqueeze(0)
.expand(
model.num_moe_layers,
-1,
-1,
)
.contiguous()
)
logical_replica_count = (
logical_replica_count.unsqueeze(0)
.expand(
model.num_moe_layers,
-1,
)
.contiguous()
)

expert_load_pass = torch.zeros(
(model.num_moe_layers, model.num_physical_experts),
@@ -246,21 +260,21 @@ class EplbState:
)
expert_load_window_size = parallel_config.eplb_config.window_size
expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers,
model.num_physical_experts),
(expert_load_window_size, model.num_moe_layers, model.num_physical_experts),
dtype=torch.int32,
device=device,
)

# Set the initial progress of rearrangement to 3/4
eplb_step_interval = parallel_config.eplb_config.step_interval
expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4)
expert_rearrangement_step = max(0, eplb_step_interval - eplb_step_interval // 4)

if global_expert_load is not None:
ep_group = get_ep_group().device_group
assert global_expert_load.shape == (model.num_moe_layers,
model.num_logical_experts)
assert global_expert_load.shape == (
model.num_moe_layers,
model.num_logical_experts,
)
assert global_expert_load.dtype == torch.int64

num_replicas = model.num_physical_experts
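A note on the `expert_rearrangement_step` initialization reformatted above: starting the counter at `step_interval - step_interval // 4` means the first rearrangement fires after only a quarter of the configured interval, and every full interval thereafter. A minimal sketch of that arithmetic, with an interval value chosen purely for illustration:

```python
# Hypothetical EPLB step interval, for illustration only.
eplb_step_interval = 3000

# Same initialization as in EplbState.build above: start at 3/4 progress.
expert_rearrangement_step = max(0, eplb_step_interval - eplb_step_interval // 4)

assert expert_rearrangement_step == 2250
# First rearrangement after 3000 - 2250 = 750 steps, then every 3000 steps.
assert eplb_step_interval - expert_rearrangement_step == 750
```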
@@ -273,20 +287,21 @@ class EplbState:
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
f"{num_gpus=}, {num_nodes=}")
f"{num_gpus=}, {num_nodes=}"
)

# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = (rebalance_experts(
) = rebalance_experts(
global_expert_load,
num_replicas,
num_groups,
num_nodes,
num_gpus,
))
)

max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= logical_to_physical_map.shape[-1]
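For context on the `rebalance_experts` call reformatted above, a small sketch of how it might be invoked directly; the module path `vllm.distributed.eplb.rebalance_algo`, the toy sizes, and the output-shape comments are assumptions for illustration, not taken from this diff:

```python
import torch

from vllm.distributed.eplb.rebalance_algo import rebalance_experts

# Illustrative sizes: 2 MoE layers, 8 logical experts, 12 physical slots,
# 4 expert groups on 1 node with 4 GPUs.
num_layers, num_logical, num_replicas = 2, 8, 12
load = torch.randint(0, 100, (num_layers, num_logical))

phy2log, log2phy, logcnt = rebalance_experts(load, num_replicas, 4, 1, 4)
# phy2log: (num_layers, num_replicas) physical slot -> logical expert
# log2phy: (num_layers, num_logical, num_redundant + 1), padded with -1
# logcnt:  (num_layers, num_logical) replica count per logical expert
```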
@@ -326,11 +341,13 @@ class EplbState:
expert_rearrangement_step_interval=eplb_step_interval,
)

def step(self,
model: MixtureOfExperts,
is_dummy: bool = False,
is_profile: bool = False,
log_stats: bool = False) -> None:
def step(
self,
model: MixtureOfExperts,
is_dummy: bool = False,
is_profile: bool = False,
log_stats: bool = False,
) -> None:
"""
Step the EPLB state.

@@ -369,32 +386,40 @@ class EplbState:
all_reduce(total_expert_load_pass, group=ep_group)

# num_tokens_per_rank: (num_moe_layers, num_ranks)
num_tokens_per_rank = total_expert_load_pass.reshape(
total_expert_load_pass.shape[0], ep_group.size(),
-1).sum(dim=-1).float()
num_tokens_per_rank = (
total_expert_load_pass.reshape(
total_expert_load_pass.shape[0], ep_group.size(), -1
)
.sum(dim=-1)
.float()
)

# Compute balancedness ratio:
# for each layer:
# (mean load across ranks) / (max load across ranks)
avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0)
max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(
dim=0)
max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(dim=0)

# Just to make type checker happy
tokens_tensors: list[float] = torch.stack(
[avg_tokens_tensor, max_tokens_tensor]).tolist()
[avg_tokens_tensor, max_tokens_tensor]
).tolist()
avg_tokens, max_tokens = tokens_tensors
balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0

if ep_group.rank() == 0:
logger.info(
"EPLB step: avg_tokens=%.2f, max_tokens=%d, "
"balancedness=%.4f", avg_tokens, max_tokens, balancedness)
"EPLB step: avg_tokens=%.2f, max_tokens=%d, balancedness=%.4f",
avg_tokens,
max_tokens,
balancedness,
)

# Update the expert load sliding window
if not is_dummy:
self.expert_load_window[self.expert_load_window_step] = (
self.expert_load_pass.clone())
self.expert_load_pass.clone()
)
self.expert_load_window_step += 1
if self.expert_load_window_step >= self.expert_load_window_size:
self.expert_load_window_step = 0
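The balancedness figure logged in the hunk above is the layer-averaged load divided by the layer-maximum load, summed across ranks; 1.0 means every rank sees the same load in every layer. A self-contained sketch with made-up token counts (not values from this diff):

```python
import torch

# Hypothetical load: 2 MoE layers x 4 EP ranks (tokens routed to each rank).
num_tokens_per_rank = torch.tensor(
    [
        [100.0, 120.0, 80.0, 100.0],
        [200.0, 180.0, 220.0, 200.0],
    ]
)

# Same reduction as in EplbState.step above.
avg_tokens = num_tokens_per_rank.mean(dim=0).sum(dim=0).item()  # 600.0
max_tokens = num_tokens_per_rank.max(dim=0).values.sum(dim=0).item()  # 800.0
balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0
print(balancedness)  # 0.75
```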
@@ -405,8 +430,7 @@ class EplbState:
# rearrangement step and perform rearrangement to ensure all ranks are
# performing collective communication.
self.expert_rearrangement_step += 1
if (self.expert_rearrangement_step
>= self.expert_rearrangement_step_interval):
if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
self.expert_rearrangement_step = 0
self.rearrange(model)

@@ -416,8 +440,8 @@ class EplbState:
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_load: Optional[torch.Tensor] = None,
rank_mapping: Optional[dict[int,
int]] = None) -> Optional[torch.Tensor]:
rank_mapping: Optional[dict[int, int]] = None,
) -> Optional[torch.Tensor]:
"""
Rearrange the experts according to the current load.
"""
@@ -430,8 +454,7 @@ class EplbState:
if is_main_rank:
torch.cuda.synchronize()
time_start = time.perf_counter()
logger.info("Rearranging experts %s...",
"(profile)" if is_profile else "")
logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")

if global_expert_load is None:
# Map the physical expert load to global logical experts
@@ -444,23 +467,25 @@ class EplbState:
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=self.physical_to_logical_map.unsqueeze(0).expand_as(
self.expert_load_window).long(),
index=self.physical_to_logical_map.unsqueeze(0)
.expand_as(self.expert_load_window)
.long(),
src=self.expert_load_window,
)

if not execute_shuffle:
metadata = torch.tensor(
[
model.num_moe_layers, model.num_logical_experts,
self.physical_to_logical_map.shape[1]
model.num_moe_layers,
model.num_logical_experts,
self.physical_to_logical_map.shape[1],
],
dtype=torch.int32,
device="cpu",
)
torch.distributed.broadcast(metadata,
group=get_ep_group().cpu_group,
group_src=0)
torch.distributed.broadcast(
metadata, group=get_ep_group().cpu_group, group_src=0
)

# Perform all-reduce to get the expert load across all ranks
global_expert_load_window = logical_expert_load_window.sum(dim=0)
@@ -469,9 +494,9 @@ class EplbState:
if not execute_shuffle:
# (num_moe_layers, old_num_physical_experts)
old_global_expert_indices = self.physical_to_logical_map
torch.distributed.broadcast(old_global_expert_indices,
group=ep_group,
group_src=0)
torch.distributed.broadcast(
old_global_expert_indices, group=ep_group, group_src=0
)
return global_expert_load_window
else:
assert execute_shuffle
@@ -486,10 +511,10 @@ class EplbState:
# the GPUs to be released.
cpu_group = get_ep_group().cpu_group
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
num_gpus = sum(new_rank != -1
for new_rank in rank_mapping.values())
num_replicas = num_replicas // ep_group.size(
) * num_gpus # handle num replicas change
num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
num_replicas = (
num_replicas // ep_group.size() * num_gpus
) # handle num replicas change
else:
num_nodes = get_node_count()
num_gpus = ep_group.size()
@@ -499,20 +524,21 @@ class EplbState:
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
f"{num_gpus=}, {num_nodes=}")
f"{num_gpus=}, {num_nodes=}"
)

# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = (rebalance_experts(
) = rebalance_experts(
global_expert_load_window,
num_replicas,
num_groups,
num_nodes,
num_gpus,
))
)

# Update expert weights
rearrange_expert_weights_inplace(
@@ -525,18 +551,20 @@ class EplbState:
)

if not is_profile:
if self.physical_to_logical_map.shape[
1] != new_physical_to_logical_map.shape[1]:
if (
self.physical_to_logical_map.shape[1]
!= new_physical_to_logical_map.shape[1]
):
self.physical_to_logical_map = new_physical_to_logical_map.to(
self.physical_to_logical_map.device)
self.physical_to_logical_map.device
)
else:
self.physical_to_logical_map.copy_(new_physical_to_logical_map)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= self.logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0,
self.logical_to_physical_map.shape[-1] - max_physical_slots),
(0, self.logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
self.logical_to_physical_map.copy_(new_logical_to_physical_map)
@@ -560,11 +588,10 @@ class EplbState:
"""
ep_group = get_ep_group()
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(metadata,
group=ep_group.cpu_group,
group_src=0)
torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
num_moe_layers, num_logical_experts, num_old_physical_experts = (
metadata.tolist())
metadata.tolist()
)
global_expert_load = torch.zeros(
(num_moe_layers, num_logical_experts),
dtype=torch.int64,
@@ -576,9 +603,9 @@ class EplbState:
dtype=torch.int64,
device=ep_group.device,
)
torch.distributed.broadcast(old_global_expert_indices,
group=ep_group.device_group,
group_src=0)
torch.distributed.broadcast(
old_global_expert_indices, group=ep_group.device_group, group_src=0
)

return global_expert_load, old_global_expert_indices


@@ -15,8 +15,9 @@ on how the EPLB algorithm works.
import torch


def balanced_packing(weight: torch.Tensor,
num_packs: int) -> tuple[torch.Tensor, torch.Tensor]:
def balanced_packing(
weight: torch.Tensor, num_packs: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly
n/m objects and the weights of all packs are as balanced as possible.
@@ -34,25 +35,21 @@ def balanced_packing(weight: torch.Tensor,
groups_per_pack = num_groups // num_packs

if groups_per_pack == 1:
pack_index = torch.arange(weight.size(-1),
dtype=torch.int64,
device=weight.device).expand(weight.shape)
pack_index = torch.arange(
weight.size(-1), dtype=torch.int64, device=weight.device
).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
return pack_index, rank_in_pack

indices = weight.float().sort(-1, descending=True).indices.cpu()
pack_index = torch.full_like(weight,
fill_value=-1,
dtype=torch.int64,
device="cpu")
pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
rank_in_pack = torch.full_like(pack_index, fill_value=-1)
for i in range(num_layers):
pack_weights = [0] * num_packs
pack_items = [0] * num_packs
for group in indices[i]:
pack = min(
(i
for i in range(num_packs) if pack_items[i] < groups_per_pack),
(i for i in range(num_packs) if pack_items[i] < groups_per_pack),
key=pack_weights.__getitem__,
)
assert pack_items[pack] < groups_per_pack
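The greedy loop above always drops the heaviest remaining group into the lightest pack that still has room, which keeps pack weights close while keeping exactly n/m groups per pack. A small usage sketch; the module path and the expected output shown in the comments are assumptions for illustration:

```python
import torch

from vllm.distributed.eplb.rebalance_algo import balanced_packing

# One layer, 6 groups with made-up weights, packed onto 2 nodes (3 groups each).
weight = torch.tensor([[50, 40, 30, 20, 10, 5]])
pack_index, rank_in_pack = balanced_packing(weight, num_packs=2)

print(pack_index)    # e.g. tensor([[0, 1, 1, 0, 0, 1]])
print(rank_in_pack)  # e.g. tensor([[0, 0, 1, 1, 2, 2]])
```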
@@ -64,8 +61,8 @@ def balanced_packing(weight: torch.Tensor,


def replicate_experts(
weight: torch.Tensor,
num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
weight: torch.Tensor, num_phy: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
@@ -83,8 +80,7 @@ def replicate_experts(
num_redundant = num_phy - num_log
assert num_redundant >= 0
device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64,
device=device).repeat(n, 1)
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device)
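`replicate_experts` greedily gives each extra physical slot to the logical expert with the highest load per existing replica. A sketch of a direct call with skewed, made-up loads; the module path and the printed values are assumptions for illustration:

```python
import torch

from vllm.distributed.eplb.rebalance_algo import replicate_experts

# One layer, 4 logical experts, replicated into 6 physical slots.
weight = torch.tensor([[90, 30, 20, 10]])
phy2log, phyrank, logcnt = replicate_experts(weight, num_phy=6)

print(phy2log)  # e.g. tensor([[0, 1, 2, 3, 0, 0]]): the hot expert gets both extra slots
print(logcnt)   # e.g. tensor([[3, 1, 1, 1]])
```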
@@ -108,7 +104,7 @@ def rebalance_experts_hierarchical(
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
num_nodes: number of server nodes, where the intra-node network
(e.g., NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`

@@ -134,45 +130,51 @@ def rebalance_experts_hierarchical(
inv.scatter_(
1,
perm,
torch.arange(perm.size(1), dtype=torch.int64,
device=perm.device).expand(perm.shape),
torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(
perm.shape
),
)
return inv

# Step 1: pack groups to nodes
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
group_pack_index, group_rank_in_pack = balanced_packing(
tokens_per_group, num_nodes)
log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) *
group_size).unsqueeze(-1) +
torch.arange(group_size,
dtype=torch.int64,
device=group_pack_index.device)).flatten(-2)
group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
log2mlog = (
(
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
).unsqueeze(-1)
+ torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
).flatten(-2)
mlog2log = inverse(log2mlog)

# Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes]
tokens_per_mlog = weight.gather(-1, mlog2log).view(
-1, num_logical_experts // num_nodes)
-1, num_logical_experts // num_nodes
)
phy2mlog, phyrank, mlogcnt = replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes)
tokens_per_mlog, num_physical_experts // num_nodes
)

# Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
pack_index, rank_in_pack = balanced_packing(tokens_per_phy,
num_gpus // num_nodes)
pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy)

pphy2mlog = phy2mlog.gather(
-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange(
0,
num_logical_experts,
num_logical_experts // num_nodes,
device=group_pack_index.device,
).view(1, -1, 1)).flatten(-2)
-1, pphy2phy
) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = (
pphy2mlog.view(num_layers, num_nodes, -1)
+ torch.arange(
0,
num_logical_experts,
num_logical_experts // num_nodes,
device=group_pack_index.device,
).view(1, -1, 1)
).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog)
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
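Steps 1-3 above chain group-to-node packing, per-node replication, and replica-to-GPU packing. A sketch of calling the hierarchical policy directly, with toy sizes chosen so the divisibility assumptions hold (groups divisible by nodes, GPUs a multiple of nodes); the module path and the values are illustrative assumptions, not taken from this diff:

```python
import torch

from vllm.distributed.eplb.rebalance_algo import rebalance_experts_hierarchical

# 1 layer, 8 logical experts in 4 groups, 2 nodes x 2 GPUs, 12 physical slots.
weight = torch.randint(0, 100, (1, 8))
pphy2log, pphyrank, logcnt = rebalance_experts_hierarchical(
    weight, num_physical_experts=12, num_groups=4, num_nodes=2, num_gpus=4
)

print(pphy2log.shape)  # torch.Size([1, 12]): physical slot -> logical expert id
```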
@@ -214,11 +216,13 @@ def rebalance_experts(
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_gpus)
weight, num_replicas, num_groups, num_nodes, num_gpus
)
else:
# use global load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_gpus)
weight, num_replicas, 1, 1, num_gpus
)
num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1
log2phy: torch.Tensor = torch.full(
@@ -230,8 +234,9 @@ def rebalance_experts(
log2phy.view(num_layers, -1).scatter_(
-1,
phy2log * maxlogcnt + phyrank,
torch.arange(num_replicas, dtype=torch.int64,
device=log2phy.device).expand(num_layers, -1),
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
num_layers, -1
),
)
return phy2log, log2phy, logcnt
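The `scatter_` above inverts `phy2log`: flattened slot `expert * maxlogcnt + replica_rank` records which physical index holds that replica, and unused slots stay at -1. A tiny self-contained sketch with hand-written values (not from this diff):

```python
import torch

# 1 layer, 3 logical experts, 4 physical slots; expert 0 has two replicas.
phy2log = torch.tensor([[0, 1, 2, 0]])
phyrank = torch.tensor([[0, 0, 0, 1]])
num_layers, num_replicas = phy2log.shape
num_logical_experts, maxlogcnt = 3, 2

log2phy = torch.full((num_layers, num_logical_experts, maxlogcnt), -1, dtype=torch.int64)
log2phy.view(num_layers, -1).scatter_(
    -1,
    phy2log * maxlogcnt + phyrank,
    torch.arange(num_replicas, dtype=torch.int64).expand(num_layers, -1),
)
print(log2phy)  # tensor([[[0, 3], [1, -1], [2, -1]]])
```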
@@ -11,8 +11,13 @@ from functools import partial
from typing import Optional

import torch
from torch.distributed import (P2POp, ProcessGroup, all_gather,
batch_isend_irecv, get_global_rank)
from torch.distributed import (
P2POp,
ProcessGroup,
all_gather,
batch_isend_irecv,
get_global_rank,
)


def idx_local_to_global(
@@ -132,8 +137,7 @@ def shuffle_layer(
continue
if old_indices[src_global] == new_indices[dst_global]:
is_received_locally[dst] = True
for weight, buffer in zip(expert_weights,
expert_weights_buffer):
for weight, buffer in zip(expert_weights, expert_weights_buffer):
buffer[dst].copy_(weight[src])

p2p_ops: list[P2POp] = []
@@ -177,7 +181,8 @@ def shuffle_layer(
torch.distributed.isend,
weight[src],
dst_global,
) for weight in expert_weights
)
for weight in expert_weights
]

# 3. Initiate receiving of weights.
@@ -216,7 +221,8 @@ def shuffle_layer(
torch.distributed.irecv,
weight[dst],
src_global,
) for weight in expert_weights_buffer
)
for weight in expert_weights_buffer
]

# 4. Execute the P2P operations. The real communication happens here.
@@ -271,29 +277,25 @@ def rearrange_expert_weights_inplace(
if rank_mapping is not None:
if len(rank_mapping) == ep_group.size():
# scale down
new_global_expert_indices = \
_map_new_expert_indices_with_rank_mapping(
new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
new_global_expert_indices,
rank_mapping,
)
else:
# scale up
old_global_expert_indices = \
_map_old_expert_indices_with_rank_mapping(
old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
old_global_expert_indices,
rank_mapping,
ep_group.size(),
)

assert old_global_expert_indices.shape[
1] == new_global_expert_indices.shape[1]
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]

num_moe_layers, num_physical_experts = old_global_expert_indices.shape
assert len(expert_weights) == num_moe_layers

num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert new_global_expert_indices.shape == (num_moe_layers,
num_physical_experts)
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)

ep_rank = ep_group.rank()
ep_size = ep_group.size()
@@ -342,13 +344,13 @@ def _map_old_expert_indices_with_rank_mapping(
) -> torch.Tensor:
"""
Map the old global expert indices to the new global expert indices.

Args:
old_global_expert_indices:
Shape (num_layers, old_ep_size * num_local_physical_experts).
rank_mapping: Mapping from old rank to new rank.
new_ep_size: New expert parallelism size.

Returns:
Mapped expert indices with shape
(num_layers, new_ep_size * num_local_physical_experts).
@@ -379,8 +381,9 @@ def _map_old_expert_indices_with_rank_mapping(
new_start_idx = new_rank * num_local_physical_experts
new_end_idx = (new_rank + 1) * num_local_physical_experts

mapped_expert_indices[:, new_start_idx:new_end_idx] = \
mapped_expert_indices[:, new_start_idx:new_end_idx] = (
old_global_expert_indices[:, old_start_idx:old_end_idx]
)
# If new_rank is None or >= new_ep_size, the experts remain -1
# (scale down case)

@@ -415,8 +418,9 @@ def _map_new_expert_indices_with_rank_mapping(
new_start_idx = new_rank * num_local_physical_experts
new_end_idx = (new_rank + 1) * num_local_physical_experts

mapped_expert_indices[:, old_start_idx:old_end_idx] = \
mapped_expert_indices[:, old_start_idx:old_end_idx] = (
new_global_expert_indices[:, new_start_idx:new_end_idx]
)

return mapped_expert_indices
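Both mapping helpers above move whole per-rank blocks of expert indices when the EP world size changes, leaving the blocks of any dropped rank at -1. A hand-worked sketch in the spirit of `_map_old_expert_indices_with_rank_mapping`, with made-up indices and a hypothetical rank mapping (none of these values come from this diff):

```python
import torch

num_local_physical_experts = 2
# One MoE layer, 3 old EP ranks holding expert indices [0,1 | 2,3 | 4,5].
old_global_expert_indices = torch.tensor([[0, 1, 2, 3, 4, 5]])
# Hypothetical mapping: old rank 1 is removed, old rank 2 becomes new rank 1.
rank_mapping = {0: 0, 1: -1, 2: 1}
new_ep_size = 2

mapped = torch.full((1, new_ep_size * num_local_physical_experts), -1, dtype=torch.int64)
for old_rank, new_rank in rank_mapping.items():
    if new_rank is None or new_rank < 0 or new_rank >= new_ep_size:
        continue  # experts of this rank stay -1
    old_block = slice(old_rank * num_local_physical_experts, (old_rank + 1) * num_local_physical_experts)
    new_block = slice(new_rank * num_local_physical_experts, (new_rank + 1) * num_local_physical_experts)
    mapped[:, new_block] = old_global_expert_indices[:, old_block]

print(mapped)  # tensor([[0, 1, 4, 5]])
```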