[Perf] Move eplb rebalance algo to async thread (#30888)
Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
@@ -295,12 +295,11 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
for layer_idx in range(num_layers):
|
||||
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
|
||||
transfer_layer(
|
||||
old_global_expert_indices=old_indices_cpu,
|
||||
new_global_expert_indices=new_indices_cpu,
|
||||
expert_weights=expert_weights,
|
||||
old_layer_indices=old_indices_cpu[layer_idx],
|
||||
new_layer_indices=new_indices_cpu[layer_idx],
|
||||
expert_weights=expert_weights[layer_idx],
|
||||
expert_weights_buffer=expert_buffer,
|
||||
ep_group=ep_group,
|
||||
layer=layer_idx,
|
||||
cuda_stream=cuda_stream,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -11,13 +11,13 @@ from typing import TYPE_CHECKING
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm.distributed.parallel_state import get_eplb_group
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .rebalance_execute import transfer_layer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .eplb_state import EplbState
|
||||
from .eplb_state import EplbModelState, EplbState
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -27,8 +27,8 @@ def start_async_worker(
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
is_profile: bool = False,
|
||||
) -> threading.Thread:
|
||||
ep_group = get_ep_group().device_group
|
||||
rank = ep_group.rank()
|
||||
eplb_group = get_eplb_group().device_group
|
||||
rank = eplb_group.rank()
|
||||
device_index = state.cuda_device_index
|
||||
assert state.is_async
|
||||
|
||||
@@ -42,7 +42,7 @@ def start_async_worker(
|
||||
loop.run_until_complete(
|
||||
transfer_run_periodically(
|
||||
state=state,
|
||||
ep_group=ep_group,
|
||||
eplb_group=eplb_group,
|
||||
cuda_stream=cuda_stream,
|
||||
is_profile=is_profile,
|
||||
rank_mapping=rank_mapping,
|
||||
@@ -58,9 +58,53 @@ def start_async_worker(
|
||||
return thread
|
||||
|
||||
|
||||
def run_rebalance_experts(
|
||||
model_state: "EplbModelState",
|
||||
eplb_state: "EplbState",
|
||||
physical_to_logical_map_cpu: torch.Tensor,
|
||||
) -> None:
|
||||
assert model_state.eplb_stats is not None
|
||||
eplb_stats = model_state.eplb_stats
|
||||
|
||||
# Wait for the main thread's all-reduce and clone to complete before
|
||||
# accessing the global_expert_load_window tensor.
|
||||
assert model_state.window_ready_event is not None
|
||||
model_state.window_ready_event.wait()
|
||||
model_state.window_ready_event = None
|
||||
|
||||
# Move the global expert load window to CPU for computation.
|
||||
global_expert_load_window = eplb_stats.global_expert_load_window.cpu()
|
||||
# Compute new expert mappings for the model
|
||||
(
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = eplb_state.policy.rebalance_experts(
|
||||
global_expert_load_window,
|
||||
eplb_stats.num_replicas,
|
||||
eplb_stats.num_groups,
|
||||
eplb_stats.num_nodes,
|
||||
eplb_stats.num_gpus,
|
||||
physical_to_logical_map_cpu,
|
||||
)
|
||||
assert new_physical_to_logical_map.device == torch.device("cpu")
|
||||
|
||||
model_state.new_physical_to_logical_map = new_physical_to_logical_map
|
||||
|
||||
max_slots = model_state.logical_to_physical_map.shape[-1]
|
||||
padded_logical = torch.nn.functional.pad(
|
||||
new_logical_to_physical_map,
|
||||
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
|
||||
value=-1,
|
||||
).to(model_state.logical_to_physical_map.device)
|
||||
new_replica = new_logical_replica_count.to(model_state.logical_replica_count.device)
|
||||
model_state.new_logical_to_physical_map = padded_logical
|
||||
model_state.new_logical_replica_count = new_replica
|
||||
|
||||
|
||||
async def transfer_run_periodically(
|
||||
state: "EplbState",
|
||||
ep_group: ProcessGroup,
|
||||
eplb_group: ProcessGroup,
|
||||
cuda_stream: torch.cuda.Stream,
|
||||
is_profile: bool = False,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
@@ -71,23 +115,51 @@ async def transfer_run_periodically(
|
||||
|
||||
assert state.is_async
|
||||
for model_state in state.model_states.values():
|
||||
rebalancing_algorithm_executed = False
|
||||
physical_to_logical_map_cpu = None
|
||||
current_num_layers = model_state.model.num_moe_layers
|
||||
while (
|
||||
model_state.rebalanced
|
||||
and model_state.layer_to_transfer < current_num_layers
|
||||
):
|
||||
if (
|
||||
not model_state.ep_buffer_ready
|
||||
and model_state.rebalanced
|
||||
and model_state.new_physical_to_logical_map is not None
|
||||
):
|
||||
await asyncio.to_thread(model_state.buffer_lock.acquire)
|
||||
if not model_state.ep_buffer_ready and model_state.rebalanced:
|
||||
# Polling the lock directly in the async thread avoids
|
||||
# the thread switch overhead of asyncio.to_thread.
|
||||
# This is typically faster than offloading to a worker thread.
|
||||
while not model_state.buffer_lock.acquire(blocking=False):
|
||||
await asyncio.sleep(0)
|
||||
try:
|
||||
if model_state.layer_to_transfer >= current_num_layers:
|
||||
break
|
||||
if (
|
||||
not rebalancing_algorithm_executed
|
||||
or model_state.new_physical_to_logical_map is None
|
||||
):
|
||||
# Move the physical_to_logical_map to CPU
|
||||
# for rebalancing and transfer_layer.
|
||||
physical_to_logical_map_cpu = (
|
||||
model_state.physical_to_logical_map.cpu()
|
||||
)
|
||||
run_rebalance_experts(
|
||||
model_state, state, physical_to_logical_map_cpu
|
||||
)
|
||||
rebalancing_algorithm_executed = True
|
||||
logger.info(
|
||||
"Async worker computed new indices for model %s",
|
||||
model_state.model_name,
|
||||
)
|
||||
|
||||
assert model_state.new_physical_to_logical_map is not None
|
||||
assert physical_to_logical_map_cpu is not None
|
||||
|
||||
layer_idx = model_state.layer_to_transfer
|
||||
old_layer_indices = physical_to_logical_map_cpu[layer_idx]
|
||||
new_layer_indices = model_state.new_physical_to_logical_map[
|
||||
layer_idx
|
||||
]
|
||||
|
||||
# Wait for the main thread to finish consuming the buffer
|
||||
# before overwriting it
|
||||
# before initiating an EPLB transfer on another layer.
|
||||
if model_state.buffer_consumed_event is not None:
|
||||
cuda_stream.wait_event(model_state.buffer_consumed_event)
|
||||
model_state.buffer_consumed_event = None
|
||||
@@ -97,13 +169,12 @@ async def transfer_run_periodically(
|
||||
model_state.is_received_locally,
|
||||
model_state.recv_metadata,
|
||||
) = await transfer_layer(
|
||||
old_global_expert_indices=model_state.physical_to_logical_map,
|
||||
new_global_expert_indices=model_state.new_physical_to_logical_map,
|
||||
expert_weights=model_state.model.expert_weights,
|
||||
old_layer_indices=old_layer_indices,
|
||||
new_layer_indices=new_layer_indices,
|
||||
expert_weights=model_state.model.expert_weights[layer_idx],
|
||||
expert_weights_buffer=model_state.expert_buffer,
|
||||
ep_group=ep_group,
|
||||
ep_group=eplb_group,
|
||||
is_profile=is_profile,
|
||||
layer=model_state.layer_to_transfer,
|
||||
cuda_stream=cuda_stream,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
|
||||
@@ -55,6 +55,35 @@ from .rebalance_execute import (
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EplbStats:
|
||||
"""
|
||||
Model stats used in EPLB rebalancing algorithm.
|
||||
"""
|
||||
|
||||
global_expert_load_window: torch.Tensor
|
||||
"""
|
||||
Experts load window.
|
||||
Shape: (window_size, num_moe_layers, num_physical_experts)
|
||||
"""
|
||||
num_replicas: int
|
||||
"""
|
||||
Number of physical experts.
|
||||
"""
|
||||
num_groups: int
|
||||
"""
|
||||
Number of expert groups.
|
||||
"""
|
||||
num_nodes: int
|
||||
"""
|
||||
Number of nodes.
|
||||
"""
|
||||
num_gpus: int
|
||||
"""
|
||||
Number of GPUs.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class EplbModelState:
|
||||
"""EPLB metrics."""
|
||||
@@ -156,6 +185,11 @@ class EplbModelState:
|
||||
CUDA event recorded after the main thread finishes consuming the buffer.
|
||||
The async worker waits on this before writing to the buffer again.
|
||||
"""
|
||||
window_ready_event: torch.cuda.Event | None
|
||||
"""
|
||||
CUDA event recorded after all-reduce and clone on the main thread.
|
||||
The async worker waits on this before accessing global_expert_load_window.
|
||||
"""
|
||||
ep_buffer_ready: int
|
||||
"""
|
||||
The flag indicates whether the expert buffer is ready for transfer.
|
||||
@@ -173,6 +207,10 @@ class EplbModelState:
|
||||
"""
|
||||
Whether the async EPLB needs to poll peers for buffer readiness.
|
||||
"""
|
||||
eplb_stats: EplbStats | None
|
||||
"""
|
||||
EPLB stats for the model.
|
||||
"""
|
||||
is_unchanged: np.ndarray
|
||||
"""
|
||||
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
||||
@@ -508,10 +546,12 @@ class EplbState:
|
||||
buffer_lock=threading.Lock(),
|
||||
buffer_ready_event=None,
|
||||
buffer_consumed_event=None,
|
||||
window_ready_event=None,
|
||||
ep_buffer_ready=0,
|
||||
layer_to_transfer=0,
|
||||
rebalanced=False,
|
||||
pending_global_ready_check=False,
|
||||
eplb_stats=None,
|
||||
is_unchanged=np.array([]),
|
||||
is_received_locally=np.array([]),
|
||||
recv_metadata=RecvMetadata(
|
||||
@@ -642,20 +682,6 @@ class EplbState:
|
||||
ep_group=ep_group,
|
||||
is_profile=is_profile,
|
||||
)
|
||||
if (
|
||||
eplb_model_state.layer_to_transfer
|
||||
>= eplb_model_state.model.num_moe_layers
|
||||
):
|
||||
self.post_eplb(eplb_model_state, is_profile)
|
||||
eplb_model_state.rebalanced = False
|
||||
eplb_model_state.layer_to_transfer = 0
|
||||
eplb_model_state.pending_global_ready_check = False
|
||||
logger.info(
|
||||
"finish async transfer for model %s rank %d layer %d",
|
||||
eplb_model_state.model_name,
|
||||
ep_group.rank(),
|
||||
eplb_model_state.model.num_moe_layers,
|
||||
)
|
||||
|
||||
if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
|
||||
if self.is_async and any(
|
||||
@@ -802,21 +828,21 @@ class EplbState:
|
||||
for eplb_model_state, global_expert_load_window in zip(
|
||||
self.model_states.values(), global_expert_load_windows
|
||||
):
|
||||
# Get new expert mappings for the model
|
||||
(
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = self.policy.rebalance_experts(
|
||||
global_expert_load_window,
|
||||
num_replicas,
|
||||
num_groups,
|
||||
num_nodes,
|
||||
num_gpus,
|
||||
eplb_model_state.physical_to_logical_map,
|
||||
)
|
||||
|
||||
if not self.is_async or is_profile:
|
||||
# Get new expert mappings for the model
|
||||
(
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = self.policy.rebalance_experts(
|
||||
global_expert_load_window,
|
||||
num_replicas,
|
||||
num_groups,
|
||||
num_nodes,
|
||||
num_gpus,
|
||||
eplb_model_state.physical_to_logical_map,
|
||||
)
|
||||
|
||||
# Update expert weights
|
||||
rearrange_expert_weights_inplace(
|
||||
eplb_model_state.physical_to_logical_map,
|
||||
@@ -873,27 +899,25 @@ class EplbState:
|
||||
gpu_elapsed,
|
||||
)
|
||||
else:
|
||||
max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
|
||||
padded_logical = torch.nn.functional.pad(
|
||||
new_logical_to_physical_map,
|
||||
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
|
||||
value=-1,
|
||||
).to(eplb_model_state.logical_to_physical_map.device)
|
||||
new_replica = new_logical_replica_count.to(
|
||||
eplb_model_state.logical_replica_count.device
|
||||
eplb_model_state.eplb_stats = EplbStats(
|
||||
# We copy the tensor to snapshot the global_expert_load_window
|
||||
# on the main thread so that async worker can access it safely
|
||||
# while the main thread is running.
|
||||
global_expert_load_window=global_expert_load_window.clone(),
|
||||
num_replicas=num_replicas,
|
||||
num_groups=num_groups,
|
||||
num_nodes=num_nodes,
|
||||
num_gpus=num_gpus,
|
||||
)
|
||||
|
||||
# Move map to cpu in advance
|
||||
eplb_model_state.new_physical_to_logical_map = (
|
||||
new_physical_to_logical_map.cpu()
|
||||
)
|
||||
eplb_model_state.new_logical_to_physical_map = padded_logical
|
||||
eplb_model_state.new_logical_replica_count = new_replica
|
||||
# Record event after clone to signal async worker
|
||||
# that load stats data is ready
|
||||
sync_event = torch.cuda.Event()
|
||||
sync_event.record()
|
||||
eplb_model_state.window_ready_event = sync_event
|
||||
|
||||
eplb_model_state.rebalanced = True
|
||||
eplb_model_state.layer_to_transfer = 0
|
||||
eplb_model_state.pending_global_ready_check = True
|
||||
|
||||
# Signal async thread to start transferring layers
|
||||
if self.is_async and (not is_profile):
|
||||
self.rearrange_event.set()
|
||||
@@ -925,11 +949,13 @@ class EplbState:
|
||||
|
||||
target_device = model_state.physical_to_logical_map.device
|
||||
new_physical = model_state.new_physical_to_logical_map
|
||||
# If the number of physical experts has changed, then the new map needs to
|
||||
# be copied synchronously to avoid a race condition with the async worker
|
||||
if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
|
||||
model_state.physical_to_logical_map = new_physical.to(target_device)
|
||||
else:
|
||||
model_state.physical_to_logical_map[layer].copy_(
|
||||
new_physical[layer].to(target_device)
|
||||
new_physical[layer].to(target_device, non_blocking=True)
|
||||
)
|
||||
|
||||
logical_device = model_state.logical_to_physical_map.device
|
||||
@@ -1004,11 +1030,9 @@ class EplbState:
|
||||
model_state.layer_to_transfer
|
||||
]
|
||||
expert_weights_buffer = model_state.expert_buffer
|
||||
new_indices = (
|
||||
model_state.new_physical_to_logical_map[model_state.layer_to_transfer]
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
new_indices = model_state.new_physical_to_logical_map[
|
||||
model_state.layer_to_transfer
|
||||
].numpy()
|
||||
move_from_buffer(
|
||||
expert_weights=expert_weights,
|
||||
expert_weights_buffers=expert_weights_buffer,
|
||||
@@ -1019,7 +1043,7 @@ class EplbState:
|
||||
ep_rank=ep_group.rank(),
|
||||
)
|
||||
# Record event after consuming buffer to signal async thread
|
||||
# that it's safe to overwrite the buffer
|
||||
# that it's safe to overwrite the intermediate buffer
|
||||
consumed_event = torch.cuda.Event()
|
||||
consumed_event.record()
|
||||
model_state.buffer_consumed_event = consumed_event
|
||||
@@ -1034,6 +1058,18 @@ class EplbState:
|
||||
model_state.model_name,
|
||||
transferred_layer,
|
||||
)
|
||||
if model_state.layer_to_transfer >= model_state.model.num_moe_layers:
|
||||
self.post_eplb(model_state, is_profile)
|
||||
model_state.rebalanced = False
|
||||
model_state.layer_to_transfer = 0
|
||||
model_state.pending_global_ready_check = False
|
||||
logger.info(
|
||||
"finish async transfer for model %s rank %d layer %d",
|
||||
model_state.model_name,
|
||||
ep_group.rank(),
|
||||
model_state.model.num_moe_layers,
|
||||
)
|
||||
|
||||
finally:
|
||||
try:
|
||||
model_state.buffer_lock.release()
|
||||
@@ -1048,9 +1084,7 @@ class EplbState:
|
||||
assert model_state.new_physical_to_logical_map is not None
|
||||
assert model_state.new_logical_to_physical_map is not None
|
||||
assert model_state.new_logical_replica_count is not None
|
||||
if not is_profile:
|
||||
for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
|
||||
self._update_layer_mapping_from_new(model_state, layer_idx)
|
||||
|
||||
model_state.new_physical_to_logical_map = None
|
||||
model_state.new_logical_to_physical_map = None
|
||||
model_state.new_logical_replica_count = None
|
||||
|
||||
@@ -434,13 +434,12 @@ def move_from_buffer(
|
||||
|
||||
|
||||
async def transfer_layer(
|
||||
old_global_expert_indices: torch.Tensor,
|
||||
new_global_expert_indices: torch.Tensor,
|
||||
expert_weights: Sequence[Sequence[torch.Tensor]],
|
||||
old_layer_indices: torch.Tensor,
|
||||
new_layer_indices: torch.Tensor,
|
||||
expert_weights: Sequence[torch.Tensor],
|
||||
expert_weights_buffer: Sequence[torch.Tensor],
|
||||
ep_group: ProcessGroup,
|
||||
is_profile: bool = False,
|
||||
layer: int = 0,
|
||||
cuda_stream: torch.cuda.Stream | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> MoveToBufferResult:
|
||||
@@ -451,56 +450,64 @@ async def transfer_layer(
|
||||
while keys are physical.
|
||||
|
||||
Args:
|
||||
old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
|
||||
new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
|
||||
expert_weights: A sequence of shape (num_moe_layers)(weight_count)
|
||||
of tensors of shape (num_local_physical_experts, hidden_size_i).
|
||||
For example, a linear layer may have up and down projection,
|
||||
so weight_count = 2. Each weight's hidden size can be different.
|
||||
old_layer_indices: Shape (num_physical_experts,).
|
||||
new_layer_indices: Shape (num_physical_experts,).
|
||||
expert_weights: Iterable of weight tensors for this layer, each with shape
|
||||
(num_local_physical_experts, hidden_size_i).
|
||||
For example, a linear layer may have up and down projection.
|
||||
expert_weights_buffer: Intermediate buffers (one per weight tensor).
|
||||
ep_group: The device process group for expert parallelism.
|
||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||
This is used during profile run, where we only perform dummy
|
||||
communications to reserve enough memory for the buffers.
|
||||
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||
rank_mapping: Optional rank mapping for elastic expert parallelism.
|
||||
|
||||
Returns:
|
||||
is_unchanged (np.ndarray): (1, num_local_experts), True where expert
|
||||
is_unchanged (np.ndarray): (num_local_experts,), True where expert
|
||||
is left unchanged.
|
||||
is_received_locally (np.ndarray): (1, num_local_experts), True where expert
|
||||
is_received_locally (np.ndarray): (num_local_experts,), True where expert
|
||||
can be received locally.
|
||||
RecvMetadata: Metadata needed for completing remote weight transfers.
|
||||
"""
|
||||
ep_size = ep_group.size()
|
||||
if rank_mapping is not None:
|
||||
# Add a layer dimension for compatibility with mapping functions
|
||||
old_layer_indices_2d = old_layer_indices.unsqueeze(0)
|
||||
new_layer_indices_2d = new_layer_indices.unsqueeze(0)
|
||||
|
||||
if len(rank_mapping) == ep_group.size():
|
||||
# scale down
|
||||
new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
|
||||
new_global_expert_indices,
|
||||
new_layer_indices_2d = _map_new_expert_indices_with_rank_mapping(
|
||||
new_layer_indices_2d,
|
||||
rank_mapping,
|
||||
)
|
||||
else:
|
||||
# scale up
|
||||
old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
|
||||
old_global_expert_indices,
|
||||
old_layer_indices_2d = _map_old_expert_indices_with_rank_mapping(
|
||||
old_layer_indices_2d,
|
||||
rank_mapping,
|
||||
ep_group.size(),
|
||||
)
|
||||
|
||||
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
|
||||
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
|
||||
assert len(expert_weights) == num_moe_layers
|
||||
# Remove the layer dimension
|
||||
old_layer_indices = old_layer_indices_2d.squeeze(0)
|
||||
new_layer_indices = new_layer_indices_2d.squeeze(0)
|
||||
|
||||
assert old_layer_indices.shape == new_layer_indices.shape
|
||||
num_physical_experts = old_layer_indices.shape[0]
|
||||
assert len(expert_weights[0]) >= 1
|
||||
num_local_physical_experts = expert_weights[0][0].shape[0]
|
||||
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
|
||||
num_local_physical_experts = expert_weights[0].shape[0]
|
||||
assert num_physical_experts == ep_size * num_local_physical_experts
|
||||
|
||||
old_global_expert_indices_np = old_global_expert_indices.cpu().numpy()
|
||||
new_global_expert_indices_np = new_global_expert_indices.cpu().numpy()
|
||||
old_layer_indices_np = old_layer_indices.cpu().numpy()
|
||||
new_layer_indices_np = new_layer_indices.cpu().numpy()
|
||||
|
||||
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
|
||||
num_local_experts=num_local_physical_experts,
|
||||
old_indices=old_global_expert_indices_np[layer],
|
||||
new_indices=new_global_expert_indices_np[layer],
|
||||
expert_weights=expert_weights[layer],
|
||||
old_indices=old_layer_indices_np,
|
||||
new_indices=new_layer_indices_np,
|
||||
expert_weights=expert_weights,
|
||||
expert_weights_buffers=expert_weights_buffer,
|
||||
cuda_stream=cuda_stream,
|
||||
ep_group=ep_group,
|
||||
|
||||
@@ -1143,6 +1143,18 @@ def get_ep_group() -> GroupCoordinator:
|
||||
return _EP
|
||||
|
||||
|
||||
_EPLB: GroupCoordinator | None = None
|
||||
|
||||
|
||||
def get_eplb_group() -> GroupCoordinator:
|
||||
assert _EPLB is not None, (
|
||||
"EPLB group is not initialized. "
|
||||
"EPLB group is only created for MoE models when EPLB is enabled. "
|
||||
"Ensure parallel_config.enable_eplb is True."
|
||||
)
|
||||
return _EPLB
|
||||
|
||||
|
||||
_PCP: GroupCoordinator | None = None
|
||||
|
||||
|
||||
@@ -1440,12 +1452,29 @@ def initialize_model_parallel(
|
||||
_EP = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="ep"
|
||||
)
|
||||
|
||||
# Create EPLB group with the same ranks as EP if EPLB is enabled.
|
||||
# This is a separate process group to isolate EPLB communications
|
||||
# from MoE forward pass collectives and prevent deadlocks when
|
||||
# using torch.distributed in execution with torch.distributed in EPLB.
|
||||
global _EPLB
|
||||
assert _EPLB is None, "EPLB group is already initialized"
|
||||
if (
|
||||
config is not None
|
||||
and config.parallel_config is not None
|
||||
and config.parallel_config.enable_eplb
|
||||
):
|
||||
# Reuse the same group_ranks from EP
|
||||
_EPLB = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="eplb"
|
||||
)
|
||||
# If no EP group needed, _EP remains None
|
||||
# If no EPLB group needed, _EPLB remains None
|
||||
|
||||
logger.info_once(
|
||||
"rank %s in world size %s is assigned as "
|
||||
"DP rank %s, PP rank %s, PCP rank %s, "
|
||||
"TP rank %s, EP rank %s",
|
||||
"TP rank %s, EP rank %s, EPLB rank %s",
|
||||
rank,
|
||||
world_size,
|
||||
_DP.rank_in_group,
|
||||
@@ -1453,6 +1482,7 @@ def initialize_model_parallel(
|
||||
_PCP.rank_in_group,
|
||||
_TP.rank_in_group,
|
||||
_EP.rank_in_group if _EP is not None else "N/A",
|
||||
_EPLB.rank_in_group if _EPLB is not None else "N/A",
|
||||
)
|
||||
|
||||
|
||||
@@ -1514,6 +1544,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module):
|
||||
_DP.prepare_communication_buffer_for_model(model)
|
||||
if _EP is not None:
|
||||
_EP.prepare_communication_buffer_for_model(model)
|
||||
if _EPLB is not None:
|
||||
_EPLB.prepare_communication_buffer_for_model(model)
|
||||
|
||||
|
||||
def model_parallel_is_initialized():
|
||||
@@ -1608,6 +1640,11 @@ def destroy_model_parallel():
|
||||
_EP.destroy()
|
||||
_EP = None
|
||||
|
||||
global _EPLB
|
||||
if _EPLB:
|
||||
_EPLB.destroy()
|
||||
_EPLB = None
|
||||
|
||||
|
||||
def destroy_distributed_environment():
|
||||
global _WORLD, _NODE_COUNT
|
||||
|
||||
Reference in New Issue
Block a user