[EPLB] Add alternative communication for EPLB weight exchange (#33176)
Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Markov Ilya <markovilya19@gmail.com> Co-authored-by: Markov Ilya <markovilya19@gmail.com>
This commit is contained in:
@@ -13,8 +13,8 @@ steps:
|
||||
- pytest -v -s distributed/test_eplb_algo.py
|
||||
- pytest -v -s distributed/test_eplb_utils.py
|
||||
|
||||
- label: EPLB Execution
|
||||
timeout_in_minutes: 20
|
||||
- label: EPLB Execution # 17min
|
||||
timeout_in_minutes: 27
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_devices: 4
|
||||
source_file_dependencies:
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import atexit
|
||||
import os
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
@@ -16,9 +18,20 @@ from vllm.utils.system_utils import update_environment_variables
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def _distributed_worker_wrapper(fn, env, world_size, args, rank, skip_queue):
|
||||
try:
|
||||
fn(env, world_size, *args)
|
||||
except BaseException as exc:
|
||||
if isinstance(exc, pytest.skip.Exception):
|
||||
skip_queue.put((rank, str(exc)))
|
||||
return
|
||||
raise
|
||||
|
||||
|
||||
def distributed_run(fn, world_size, *args):
|
||||
number_of_processes = world_size
|
||||
processes: list[mp.Process] = []
|
||||
skip_queue: mp.SimpleQueue = mp.SimpleQueue()
|
||||
for i in range(number_of_processes):
|
||||
env: dict[str, str] = {}
|
||||
env["RANK"] = str(i)
|
||||
@@ -27,13 +40,32 @@ def distributed_run(fn, world_size, *args):
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12345"
|
||||
p = mp.Process(target=fn, args=(env, world_size, *args))
|
||||
p = mp.Process(
|
||||
target=_distributed_worker_wrapper,
|
||||
args=(fn, env, world_size, args, i, skip_queue),
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
skipped: list[tuple[int, str]] = []
|
||||
while not skip_queue.empty():
|
||||
rank, reason = skip_queue.get()
|
||||
skipped.append((rank, reason))
|
||||
|
||||
if len(skipped) == number_of_processes:
|
||||
reason = skipped[0][1]
|
||||
pytest.skip(reason)
|
||||
if 0 < len(skipped) < number_of_processes:
|
||||
skipped_ranks = sorted(rank for rank, _ in skipped)
|
||||
raise AssertionError(
|
||||
"Distributed test had partial skips; expected either all ranks "
|
||||
f"to skip or none. Skipped ranks: {skipped_ranks}, "
|
||||
f"total ranks: {number_of_processes}"
|
||||
)
|
||||
|
||||
for p in processes:
|
||||
assert p.exitcode == 0
|
||||
|
||||
@@ -48,7 +80,12 @@ def set_env_vars_and_device(env: dict[str, str]) -> None:
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
init_distributed_environment()
|
||||
|
||||
atexit.register(_destroy_process_group_if_initialized)
|
||||
# Ensure each worker process has the same random seed
|
||||
random.seed(42)
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
def _destroy_process_group_if_initialized() -> None:
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch
|
||||
import torch.distributed
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
|
||||
from vllm.distributed.eplb.rebalance_execute import (
|
||||
move_from_buffer,
|
||||
rearrange_expert_weights_inplace,
|
||||
@@ -130,9 +131,10 @@ def verify_expert_weights_after_shuffle(
|
||||
hidden_sizes: list[int],
|
||||
ep_rank: int,
|
||||
num_local_experts: int,
|
||||
):
|
||||
) -> bool:
|
||||
"""Verify the weights after shuffling are correct."""
|
||||
num_layers = len(expert_weights)
|
||||
ok = True
|
||||
|
||||
for layer in range(num_layers):
|
||||
for weight_idx, hidden_size in enumerate(hidden_sizes):
|
||||
@@ -155,29 +157,38 @@ def verify_expert_weights_after_shuffle(
|
||||
dtype=actual_weights.dtype,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
actual_weights,
|
||||
expected_weights,
|
||||
msg=f"Layer {layer}, weight {weight_idx},"
|
||||
f"local expert {local_expert}: "
|
||||
f"weights do not match. "
|
||||
f"Expected logical expert {expected_logical_expert}",
|
||||
)
|
||||
if not torch.equal(actual_weights, expected_weights):
|
||||
ok = False
|
||||
actual_head = actual_weights[:8].detach().cpu().tolist()
|
||||
expected_head = expected_weights[:8].detach().cpu().tolist()
|
||||
print(
|
||||
"verify_expert_weights_after_shuffle failed: "
|
||||
f"rank={ep_rank}, "
|
||||
f"layer={layer}, weight_idx={weight_idx}, "
|
||||
f"local_expert={local_expert}, "
|
||||
f"expected_logical_expert={expected_logical_expert}, "
|
||||
f"actual_head={actual_head}, expected_head={expected_head}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
return ok
|
||||
|
||||
|
||||
def verify_redundant_experts_have_same_weights(
|
||||
expert_weights: list[list[torch.Tensor]],
|
||||
indices: torch.Tensor,
|
||||
hidden_sizes: list[int],
|
||||
ep_rank: int,
|
||||
world_size: int,
|
||||
num_local_experts: int,
|
||||
):
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that all replicas of the same logical expert have the same weights.
|
||||
"""
|
||||
num_layers = len(expert_weights)
|
||||
total_physical_experts = world_size * num_local_experts
|
||||
|
||||
ok = True
|
||||
for layer in range(num_layers):
|
||||
# Collect weights for all physical experts for each weight matrix
|
||||
all_weights: list[torch.Tensor] = []
|
||||
@@ -227,14 +238,54 @@ def verify_redundant_experts_have_same_weights(
|
||||
# Verify that current physical expert's weights match the
|
||||
# previously saved logical expert weights
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
torch.testing.assert_close(
|
||||
if not torch.equal(
|
||||
all_weights[weight_idx][physical_pos],
|
||||
logical_expert_weights[logical_expert_id][weight_idx],
|
||||
msg=f"Layer {layer}, weight {weight_idx},"
|
||||
f"logical expert {logical_expert_id}: "
|
||||
f"Physical expert {physical_pos} has different weights"
|
||||
f"than expected",
|
||||
)
|
||||
):
|
||||
ok = False
|
||||
actual_head = (
|
||||
all_weights[weight_idx][physical_pos][:8]
|
||||
.detach()
|
||||
.cpu()
|
||||
.tolist()
|
||||
)
|
||||
reference_head = (
|
||||
logical_expert_weights[logical_expert_id][weight_idx][:8]
|
||||
.detach()
|
||||
.cpu()
|
||||
.tolist()
|
||||
)
|
||||
print(
|
||||
"verify_redundant_experts_have_same_weights failed: "
|
||||
f"rank={ep_rank}, "
|
||||
f"layer={layer}, weight_idx={weight_idx}, "
|
||||
f"logical_expert={logical_expert_id}, "
|
||||
f"physical_pos={physical_pos}, "
|
||||
f"actual_head={actual_head}, "
|
||||
f"reference_head={reference_head}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
return ok
|
||||
|
||||
|
||||
def assert_verification_synced(local_ok: bool, msg: str) -> None:
|
||||
ok_tensor = torch.tensor([1 if local_ok else 0], device="cuda", dtype=torch.int32)
|
||||
torch.distributed.all_reduce(ok_tensor, op=torch.distributed.ReduceOp.MIN)
|
||||
assert bool(ok_tensor.item()), msg
|
||||
|
||||
|
||||
def create_eplb_communicator_or_raise(*, group_coordinator, backend, expert_weights):
|
||||
try:
|
||||
return create_eplb_communicator(
|
||||
group_coordinator=group_coordinator,
|
||||
backend=backend,
|
||||
expert_weights=expert_weights,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
f"Failed to create EPLB communicator for backend={backend}: {exc}"
|
||||
) from exc
|
||||
|
||||
|
||||
def _test_async_transfer_layer_without_mtp_worker(
|
||||
@@ -243,6 +294,7 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
num_layers: int,
|
||||
num_local_experts: int,
|
||||
num_logical_experts: int,
|
||||
eplb_communicator: str,
|
||||
) -> None:
|
||||
set_env_vars_and_device(env)
|
||||
|
||||
@@ -254,8 +306,8 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
tp_group = get_tp_group()
|
||||
ep_group = tp_group.device_group
|
||||
ep_group_coordinator = get_tp_group()
|
||||
ep_group = ep_group_coordinator.device_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
@@ -298,6 +350,13 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
cuda_stream = torch.cuda.Stream(device=device)
|
||||
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend=eplb_communicator,
|
||||
expert_weights=expert_weights[0],
|
||||
)
|
||||
communicator.set_stream(cuda_stream)
|
||||
|
||||
for layer_idx in range(num_layers):
|
||||
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
|
||||
transfer_layer(
|
||||
@@ -306,6 +365,7 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
expert_weights=expert_weights[layer_idx],
|
||||
expert_weights_buffer=expert_buffer,
|
||||
ep_group=ep_group,
|
||||
communicator=communicator,
|
||||
cuda_stream=cuda_stream,
|
||||
)
|
||||
)
|
||||
@@ -320,24 +380,38 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
ep_rank=ep_rank,
|
||||
)
|
||||
|
||||
verify_expert_weights_after_shuffle(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
num_local_experts,
|
||||
)
|
||||
local_ok = verify_expert_weights_after_shuffle(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
num_local_experts,
|
||||
)
|
||||
local_ok = (
|
||||
verify_redundant_experts_have_same_weights(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
world_size,
|
||||
num_local_experts,
|
||||
)
|
||||
and local_ok
|
||||
)
|
||||
assert_verification_synced(
|
||||
local_ok,
|
||||
"Async transfer verification failed on at least one rank. "
|
||||
"See logs for details.",
|
||||
)
|
||||
|
||||
|
||||
def _test_rearrange_expert_weights_with_redundancy(
|
||||
env, world_size, num_layers, num_local_experts, num_logical_experts
|
||||
env,
|
||||
world_size,
|
||||
num_layers,
|
||||
num_local_experts,
|
||||
num_logical_experts,
|
||||
eplb_communicator: str,
|
||||
) -> None:
|
||||
# Initialize model parallel (using tensor parallel as an entrypoint
|
||||
# to expert parallel)
|
||||
@@ -351,7 +425,8 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_group_coordinator = get_tp_group()
|
||||
ep_group = ep_group_coordinator.cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
@@ -387,6 +462,12 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||
)
|
||||
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend=eplb_communicator,
|
||||
expert_weights=expert_weights[0],
|
||||
)
|
||||
|
||||
# Execute weight rearrangement
|
||||
rearrange_expert_weights_inplace(
|
||||
old_indices,
|
||||
@@ -394,24 +475,33 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
expert_weights,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
communicator=communicator,
|
||||
)
|
||||
|
||||
# Verify the rearrangement result
|
||||
verify_expert_weights_after_shuffle(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
num_local_experts,
|
||||
)
|
||||
# Verify the rearrangement result
|
||||
local_ok = verify_expert_weights_after_shuffle(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
num_local_experts,
|
||||
)
|
||||
|
||||
local_ok = (
|
||||
verify_redundant_experts_have_same_weights(
|
||||
expert_weights,
|
||||
new_indices,
|
||||
hidden_sizes,
|
||||
ep_rank,
|
||||
world_size,
|
||||
num_local_experts,
|
||||
)
|
||||
and local_ok
|
||||
)
|
||||
assert_verification_synced(
|
||||
local_ok,
|
||||
"Rearrange verification failed on at least one rank. See logs for details.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -437,8 +527,13 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
(4, 8, 8, 16),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl"])
|
||||
def test_rearrange_expert_weights_with_redundancy(
|
||||
world_size, num_layers, num_local_experts, num_logical_experts
|
||||
world_size,
|
||||
num_layers,
|
||||
num_local_experts,
|
||||
num_logical_experts,
|
||||
eplb_communicator,
|
||||
):
|
||||
"""Test the functionality of rearranging expert weights with redundancy."""
|
||||
|
||||
@@ -450,6 +545,7 @@ def test_rearrange_expert_weights_with_redundancy(
|
||||
num_layers,
|
||||
num_local_experts,
|
||||
num_logical_experts,
|
||||
eplb_communicator,
|
||||
)
|
||||
|
||||
|
||||
@@ -464,7 +560,8 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_group_coordinator = get_tp_group()
|
||||
ep_group = ep_group_coordinator.cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
@@ -494,24 +591,40 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||
layer_copy.append(weight.clone())
|
||||
original_weights.append(layer_copy)
|
||||
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend="torch_nccl",
|
||||
expert_weights=expert_weights[0],
|
||||
)
|
||||
|
||||
# Execute rearrangement (should be no change)
|
||||
rearrange_expert_weights_inplace(
|
||||
indices,
|
||||
indices, # Same indices
|
||||
expert_weights,
|
||||
ep_group,
|
||||
communicator,
|
||||
is_profile=False,
|
||||
)
|
||||
|
||||
# Verify that the weights have not changed
|
||||
for layer in range(num_layers):
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
torch.testing.assert_close(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
msg=f"""Layer {layer}, weight {weight_idx}
|
||||
should remain unchanged""",
|
||||
# Verify that the weights have not changed
|
||||
local_ok = True
|
||||
for layer in range(num_layers):
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
if not torch.equal(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
):
|
||||
local_ok = False
|
||||
print(
|
||||
"test_rearrange_expert_weights_no_change failed: "
|
||||
f"layer={layer}, weight_idx={weight_idx}",
|
||||
flush=True,
|
||||
)
|
||||
assert_verification_synced(
|
||||
local_ok,
|
||||
"No-change EPLB verification failed on at least one rank.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -520,11 +633,13 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||
(2, 2, 2, 3),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("eplb_communicator", ["torch_nccl", "torch_gloo", "pynccl"])
|
||||
def test_async_transfer_layer_without_mtp(
|
||||
world_size: int,
|
||||
num_layers: int,
|
||||
num_local_experts: int,
|
||||
num_logical_experts: int,
|
||||
eplb_communicator: str,
|
||||
):
|
||||
"""Exercise async EPLB transfer path without MTP/spec decode."""
|
||||
|
||||
@@ -537,6 +652,7 @@ def test_async_transfer_layer_without_mtp(
|
||||
num_layers,
|
||||
num_local_experts,
|
||||
num_logical_experts,
|
||||
eplb_communicator,
|
||||
)
|
||||
|
||||
|
||||
@@ -549,7 +665,10 @@ def test_rearrange_expert_weights_no_change(world_size):
|
||||
|
||||
if torch.accelerator.device_count() < world_size:
|
||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||
distributed_run(_test_rearrange_expert_weights_no_change, world_size)
|
||||
distributed_run(
|
||||
_test_rearrange_expert_weights_no_change,
|
||||
world_size,
|
||||
)
|
||||
|
||||
|
||||
def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||
@@ -563,7 +682,8 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||
tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1
|
||||
)
|
||||
|
||||
ep_group = get_tp_group().cpu_group
|
||||
ep_group_coordinator = get_tp_group()
|
||||
ep_group = ep_group_coordinator.cpu_group
|
||||
ep_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{ep_rank}")
|
||||
|
||||
@@ -600,23 +720,40 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||
layer_copy.append(weight.clone())
|
||||
original_weights.append(layer_copy)
|
||||
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend="torch_nccl",
|
||||
expert_weights=expert_weights[0],
|
||||
)
|
||||
|
||||
# Execute profile mode rearrangement
|
||||
rearrange_expert_weights_inplace(
|
||||
old_indices,
|
||||
new_indices,
|
||||
expert_weights,
|
||||
ep_group,
|
||||
communicator,
|
||||
is_profile=True, # Profile mode
|
||||
)
|
||||
|
||||
# In profile mode, the weights should remain unchanged
|
||||
for layer in range(num_layers):
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
torch.testing.assert_close(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
msg="In profile mode, the weights should remain unchanged",
|
||||
# In profile mode, the weights should remain unchanged
|
||||
local_ok = True
|
||||
for layer in range(num_layers):
|
||||
for weight_idx in range(len(hidden_sizes)):
|
||||
if not torch.equal(
|
||||
expert_weights[layer][weight_idx],
|
||||
original_weights[layer][weight_idx],
|
||||
):
|
||||
local_ok = False
|
||||
print(
|
||||
"test_rearrange_expert_weights_profile_mode failed: "
|
||||
f"layer={layer}, weight_idx={weight_idx}",
|
||||
flush=True,
|
||||
)
|
||||
assert_verification_synced(
|
||||
local_ok,
|
||||
"Profile-mode EPLB verification failed on at least one rank.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
@@ -625,4 +762,7 @@ def test_rearrange_expert_weights_profile_mode(world_size):
|
||||
|
||||
if torch.accelerator.device_count() < world_size:
|
||||
pytest.skip(f"Need at least {world_size} GPUs to run the test")
|
||||
distributed_run(_test_rearrange_expert_weights_profile_mode, world_size)
|
||||
distributed_run(
|
||||
_test_rearrange_expert_weights_profile_mode,
|
||||
world_size,
|
||||
)
|
||||
|
||||
@@ -35,6 +35,7 @@ DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
|
||||
DataParallelBackend = Literal["ray", "mp"]
|
||||
EPLBPolicyOption = Literal["default"]
|
||||
DCPCommBackend = Literal["ag_rs", "a2a"]
|
||||
EPLBCommunicatorBackend = Literal["torch_nccl", "torch_gloo", "pynccl"]
|
||||
All2AllBackend = Literal[
|
||||
"naive",
|
||||
"pplx",
|
||||
@@ -83,6 +84,15 @@ class EPLBConfig:
|
||||
policy: EPLBPolicyOption = "default"
|
||||
"""The policy type for expert parallel load balancing (EPLB)."""
|
||||
|
||||
communicator: EPLBCommunicatorBackend | None = None
|
||||
"""
|
||||
Backend for EPLB expert weight communication:
|
||||
- "torch_nccl": Use torch.distributed on the device process group
|
||||
- "torch_gloo": Use torch.distributed gloo with CPU staging
|
||||
- "pynccl": Use PyNccl send/recv
|
||||
- None: Auto-select backend ("torch_gloo" for async, "torch_nccl" for sync)
|
||||
"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_eplb_config(self) -> Self:
|
||||
if self.use_async and self.policy != "default":
|
||||
@@ -764,16 +774,18 @@ class ParallelConfig:
|
||||
"backend is mp, uni or external_launcher."
|
||||
)
|
||||
|
||||
if (
|
||||
self.all2all_backend in ("allgather_reducescatter")
|
||||
and self.eplb_config.use_async
|
||||
):
|
||||
logger.warning(
|
||||
"Async EPLB causes hangs with the '%s' all2all backend. "
|
||||
"Forcing synchronous EPLB.",
|
||||
self.all2all_backend,
|
||||
)
|
||||
self.eplb_config.use_async = False
|
||||
if self.enable_eplb and self.eplb_config.communicator is None:
|
||||
if self.enable_elastic_ep:
|
||||
# Elastic EP requires stateless mode
|
||||
# (torch.distributed.batch_isend_irecv doesn't
|
||||
# support stateless mode), so we use PyNCCL backend
|
||||
self.eplb_config.communicator = "pynccl"
|
||||
elif self.eplb_config.use_async:
|
||||
# Torch Gloo is a backend that allows avoiding hangs
|
||||
# due to NCCL multi-thread conflicts in async EPLB
|
||||
self.eplb_config.communicator = "torch_gloo"
|
||||
else:
|
||||
self.eplb_config.communicator = "torch_nccl"
|
||||
|
||||
@property
|
||||
def use_ray(self) -> bool:
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import functools
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
@@ -75,26 +76,34 @@ class ncclDataTypeEnum:
|
||||
ncclFloat8e4m3 = 10
|
||||
ncclNumTypes = 11
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _torch_to_nccl_map(cls) -> dict[torch.dtype, int]:
|
||||
return {
|
||||
torch.int8: cls.ncclInt8,
|
||||
torch.uint8: cls.ncclUint8,
|
||||
torch.int32: cls.ncclInt32,
|
||||
torch.int64: cls.ncclInt64,
|
||||
torch.float16: cls.ncclFloat16,
|
||||
torch.float32: cls.ncclFloat32,
|
||||
torch.float64: cls.ncclFloat64,
|
||||
torch.bfloat16: cls.ncclBfloat16,
|
||||
current_platform.fp8_dtype(): cls.ncclFloat8e4m3,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def supports_torch_dtype(cls, dtype: torch.dtype) -> bool:
|
||||
return dtype in cls._torch_to_nccl_map()
|
||||
|
||||
@classmethod
|
||||
def try_from_torch(cls, dtype: torch.dtype) -> int | None:
|
||||
return cls._torch_to_nccl_map().get(dtype)
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.ncclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.ncclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.ncclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.ncclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.ncclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.ncclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.ncclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.ncclBfloat16
|
||||
if dtype == current_platform.fp8_dtype():
|
||||
return cls.ncclFloat8e4m3
|
||||
nccl_dtype = cls.try_from_torch(dtype)
|
||||
if nccl_dtype is not None:
|
||||
return nccl_dtype
|
||||
raise ValueError(
|
||||
f"Unsupported dtype {dtype}: should be one of "
|
||||
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16,"
|
||||
|
||||
@@ -29,8 +29,10 @@ from vllm.distributed.elastic_ep.standby_state import (
|
||||
get_standby_ep_group,
|
||||
pop_standby_groups,
|
||||
)
|
||||
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
|
||||
from vllm.distributed.parallel_state import (
|
||||
_replace_active_groups,
|
||||
get_eplb_group,
|
||||
prepare_communication_buffer_for_model,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
@@ -411,6 +413,13 @@ class ElasticEPScalingExecutor:
|
||||
module.quant_method = module.quant_method.old_quant_method
|
||||
module.runner = module._init_runner()
|
||||
prepare_communication_buffer_for_model(self.worker.model_runner.model)
|
||||
|
||||
eplb_model_state.communicator = create_eplb_communicator(
|
||||
group_coordinator=get_eplb_group(),
|
||||
backend=parallel_config.eplb_config.communicator,
|
||||
expert_weights=model.expert_weights[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.worker.vllm_config.compilation_config.mode
|
||||
== CompilationMode.STOCK_TORCH_COMPILE
|
||||
|
||||
@@ -98,6 +98,8 @@ async def transfer_run_periodically(
|
||||
|
||||
assert state.is_async
|
||||
for model_state in state.model_states.values():
|
||||
# Set the async worker's CUDA stream on the communicator
|
||||
model_state.communicator.set_stream(cuda_stream)
|
||||
rebalancing_algorithm_executed = False
|
||||
physical_to_logical_map_cpu = None
|
||||
current_num_layers = model_state.model.num_moe_layers
|
||||
@@ -157,6 +159,7 @@ async def transfer_run_periodically(
|
||||
expert_weights=model_state.model.expert_weights[layer_idx],
|
||||
expert_weights_buffer=model_state.expert_buffer,
|
||||
ep_group=eplb_group,
|
||||
communicator=model_state.communicator,
|
||||
is_profile=is_profile,
|
||||
cuda_stream=cuda_stream,
|
||||
)
|
||||
|
||||
277
vllm/distributed/eplb/eplb_communicator.py
Normal file
277
vllm/distributed/eplb/eplb_communicator.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
EPLB communicator implementations and factory.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
from torch.distributed import (
|
||||
P2POp,
|
||||
ProcessGroup,
|
||||
batch_isend_irecv,
|
||||
)
|
||||
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||
ncclDataTypeEnum,
|
||||
)
|
||||
from vllm.distributed.parallel_state import GroupCoordinator, is_local_first_rank
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EplbCommunicator(ABC):
|
||||
"""Abstract EPLB communicator for expert weight transfers."""
|
||||
|
||||
@abstractmethod
|
||||
def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def execute(self) -> None:
|
||||
pass
|
||||
|
||||
def set_stream(self, cuda_stream: torch.cuda.Stream | None) -> None:
|
||||
self._cuda_stream = cuda_stream
|
||||
|
||||
def _log_initialized(self) -> None:
|
||||
if is_local_first_rank():
|
||||
logger.info("Initialized EPLB communicator: %s.", self.__class__.__name__)
|
||||
|
||||
|
||||
class TorchDistNcclEplbCommunicator(EplbCommunicator):
|
||||
"""EPLB communicator backed by torch.distributed isend/irecv."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ep_group: ProcessGroup,
|
||||
cuda_stream: torch.cuda.Stream | None = None,
|
||||
) -> None:
|
||||
self._ep_group = ep_group
|
||||
self._cuda_stream = cuda_stream
|
||||
self._p2p_ops: list[P2POp] = []
|
||||
self._log_initialized()
|
||||
|
||||
def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
|
||||
self._p2p_ops.append(
|
||||
P2POp(
|
||||
torch.distributed.isend,
|
||||
tensor,
|
||||
dst_rank,
|
||||
self._ep_group,
|
||||
)
|
||||
)
|
||||
|
||||
def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
|
||||
self._p2p_ops.append(
|
||||
P2POp(
|
||||
torch.distributed.irecv,
|
||||
tensor,
|
||||
src_rank,
|
||||
self._ep_group,
|
||||
)
|
||||
)
|
||||
|
||||
def execute(self) -> None:
|
||||
if not self._p2p_ops:
|
||||
return
|
||||
try:
|
||||
with torch.cuda.stream(self._cuda_stream):
|
||||
reqs = batch_isend_irecv(self._p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
finally:
|
||||
self._p2p_ops.clear()
|
||||
|
||||
|
||||
class TorchDistGlooStagedEplbCommunicator(EplbCommunicator):
|
||||
"""EPLB communicator using gloo P2P with CPU staging."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
cuda_stream: torch.cuda.Stream | None = None,
|
||||
) -> None:
|
||||
self._cpu_group = cpu_group
|
||||
self._cuda_stream = cuda_stream
|
||||
self._ops: list[tuple[str, torch.Tensor, int]] = []
|
||||
self._log_initialized()
|
||||
|
||||
def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
|
||||
self._ops.append(("send", tensor, dst_rank))
|
||||
|
||||
def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
|
||||
self._ops.append(("recv", tensor, src_rank))
|
||||
|
||||
def execute(self) -> None:
|
||||
if not self._ops:
|
||||
return
|
||||
|
||||
p2p_ops: list[P2POp] = []
|
||||
recv_staging: list[tuple[torch.Tensor, torch.Tensor]] = []
|
||||
|
||||
def build_ops() -> None:
|
||||
for op, tensor, peer_rank in self._ops:
|
||||
if op == "send":
|
||||
cpu_tensor = tensor.to(device="cpu", non_blocking=True)
|
||||
p2p_ops.append(
|
||||
P2POp(
|
||||
torch.distributed.isend,
|
||||
cpu_tensor,
|
||||
peer_rank,
|
||||
self._cpu_group,
|
||||
)
|
||||
)
|
||||
continue
|
||||
cpu_tensor = torch.empty_like(tensor, device="cpu")
|
||||
p2p_ops.append(
|
||||
P2POp(
|
||||
torch.distributed.irecv,
|
||||
cpu_tensor,
|
||||
peer_rank,
|
||||
self._cpu_group,
|
||||
)
|
||||
)
|
||||
recv_staging.append((tensor, cpu_tensor))
|
||||
|
||||
try:
|
||||
with torch.cuda.stream(self._cuda_stream):
|
||||
build_ops()
|
||||
finally:
|
||||
self._ops.clear()
|
||||
|
||||
# Wait for all D2H copies to finish
|
||||
# before issuing gloo batch_isend_irecv operations.
|
||||
if self._cuda_stream is not None:
|
||||
self._cuda_stream.synchronize()
|
||||
else:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
if not recv_staging:
|
||||
return
|
||||
with torch.cuda.stream(self._cuda_stream):
|
||||
for dst_tensor, cpu_tensor in recv_staging:
|
||||
dst_tensor.copy_(cpu_tensor, non_blocking=True)
|
||||
|
||||
|
||||
class PyNcclEplbCommunicator(EplbCommunicator):
|
||||
"""EPLB communicator backed by PyNcclCommunicator using ncclSend/ncclRecv."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pynccl_comm: PyNcclCommunicator,
|
||||
cuda_stream: torch.cuda.Stream | None = None,
|
||||
) -> None:
|
||||
self._pynccl_comm = pynccl_comm
|
||||
self._cuda_stream = cuda_stream
|
||||
self._group_started = False
|
||||
self._log_initialized()
|
||||
|
||||
def _ensure_group_started(self) -> None:
|
||||
if not self._group_started:
|
||||
self._pynccl_comm.group_start()
|
||||
self._group_started = True
|
||||
|
||||
def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None:
|
||||
self._ensure_group_started()
|
||||
self._pynccl_comm.send(tensor, dst_rank, stream=self._cuda_stream)
|
||||
|
||||
def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None:
|
||||
self._ensure_group_started()
|
||||
self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream)
|
||||
|
||||
def execute(self) -> None:
|
||||
if self._group_started:
|
||||
self._pynccl_comm.group_end()
|
||||
self._group_started = False
|
||||
|
||||
|
||||
def create_eplb_communicator(
|
||||
group_coordinator: GroupCoordinator,
|
||||
backend: str | None,
|
||||
expert_weights: Sequence[torch.Tensor],
|
||||
) -> EplbCommunicator:
|
||||
# Keep a safe default for callers that have not resolved communicator yet.
|
||||
if backend is None:
|
||||
backend = "torch_nccl"
|
||||
|
||||
tensor_device_type = expert_weights[0].device.type if expert_weights else "cpu"
|
||||
torch_group = (
|
||||
group_coordinator.cpu_group
|
||||
if tensor_device_type == "cpu"
|
||||
else group_coordinator.device_group
|
||||
)
|
||||
|
||||
def _create_pynccl() -> EplbCommunicator:
|
||||
if tensor_device_type == "cpu":
|
||||
raise RuntimeError(
|
||||
"EPLB communicator 'pynccl' supports only cuda-like devices "
|
||||
f"(got {tensor_device_type})."
|
||||
)
|
||||
unsupported_dtypes = sorted(
|
||||
{
|
||||
tensor.dtype
|
||||
for tensor in expert_weights
|
||||
if not ncclDataTypeEnum.supports_torch_dtype(tensor.dtype)
|
||||
},
|
||||
key=str,
|
||||
)
|
||||
if unsupported_dtypes:
|
||||
raise RuntimeError(
|
||||
"EPLB communicator 'pynccl' requested but expert weights contain "
|
||||
"unsupported dtypes: "
|
||||
f"({', '.join(str(dtype) for dtype in unsupported_dtypes)})."
|
||||
)
|
||||
|
||||
device_comm = group_coordinator.device_communicator
|
||||
pynccl_comm = (
|
||||
getattr(device_comm, "pynccl_comm", None)
|
||||
if device_comm is not None
|
||||
else None
|
||||
)
|
||||
if pynccl_comm is None or pynccl_comm.disabled or not pynccl_comm.available:
|
||||
raise RuntimeError("EPLB communicator 'pynccl' requested but unavailable.")
|
||||
try:
|
||||
return PyNcclEplbCommunicator(pynccl_comm=pynccl_comm)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
f"Failed to initialize PyNcclEplbCommunicator ({exc})."
|
||||
) from exc
|
||||
|
||||
is_stateless = isinstance(group_coordinator, StatelessGroupCoordinator)
|
||||
if is_stateless:
|
||||
if backend not in ("torch_nccl", "pynccl"):
|
||||
raise ValueError(
|
||||
f"Elastic EP requires 'torch_nccl' or 'pynccl' EPLB communicator "
|
||||
f"(got '{backend}'). torch_gloo is not supported with stateless groups."
|
||||
)
|
||||
if backend == "torch_nccl":
|
||||
logger.warning(
|
||||
"Stateless elastic EP requires PyNCCL backend. "
|
||||
"Forcing EPLB communicator to 'pynccl'."
|
||||
)
|
||||
backend = "pynccl"
|
||||
return _create_pynccl()
|
||||
|
||||
if backend == "torch_gloo":
|
||||
return TorchDistGlooStagedEplbCommunicator(
|
||||
cpu_group=group_coordinator.cpu_group,
|
||||
)
|
||||
elif backend == "torch_nccl":
|
||||
return TorchDistNcclEplbCommunicator(ep_group=torch_group)
|
||||
elif backend == "pynccl":
|
||||
return _create_pynccl()
|
||||
raise ValueError(f"Unknown EPLB communicator backend: {backend}")
|
||||
@@ -37,6 +37,7 @@ from torch.distributed import ProcessGroup, all_reduce
|
||||
from vllm.config import ModelConfig, ParallelConfig
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_ep_group,
|
||||
get_eplb_group,
|
||||
get_node_count,
|
||||
in_the_same_node_as,
|
||||
)
|
||||
@@ -46,6 +47,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MixtureOfExperts
|
||||
|
||||
from .async_worker import start_async_worker
|
||||
from .eplb_communicator import EplbCommunicator, create_eplb_communicator
|
||||
from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
|
||||
from .rebalance_execute import (
|
||||
RecvMetadata,
|
||||
@@ -225,6 +227,10 @@ class EplbModelState:
|
||||
"""
|
||||
CUDA device index for the async EPLB worker thread.
|
||||
"""
|
||||
communicator: EplbCommunicator
|
||||
"""
|
||||
The communicator for expert weight transfers.
|
||||
"""
|
||||
new_physical_to_logical_map: torch.Tensor | None = None
|
||||
"""
|
||||
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
||||
@@ -472,6 +478,12 @@ class EplbState:
|
||||
self._init_should_record_tensor(model)
|
||||
expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
|
||||
|
||||
communicator = create_eplb_communicator(
|
||||
group_coordinator=get_eplb_group(),
|
||||
backend=self.parallel_config.eplb_config.communicator,
|
||||
expert_weights=model.expert_weights[0],
|
||||
)
|
||||
|
||||
model_state = EplbModelState(
|
||||
physical_to_logical_map=physical_to_logical_map,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
@@ -498,6 +510,7 @@ class EplbState:
|
||||
recv_dst_rows=np.array([]),
|
||||
),
|
||||
cuda_device_index=self.cuda_device_index,
|
||||
communicator=communicator,
|
||||
new_physical_to_logical_map=None,
|
||||
)
|
||||
self.model_states[model_config.compute_hash()] = model_state
|
||||
@@ -800,6 +813,7 @@ class EplbState:
|
||||
new_physical_to_logical_map,
|
||||
eplb_model_state.model.expert_weights,
|
||||
ep_group,
|
||||
eplb_model_state.communicator,
|
||||
is_profile,
|
||||
rank_mapping,
|
||||
)
|
||||
@@ -923,11 +937,8 @@ class EplbState:
|
||||
new_indices=new_indices,
|
||||
ep_rank=ep_group.rank(),
|
||||
)
|
||||
# Record event after consuming buffer to signal async thread
|
||||
# that it's safe to overwrite the intermediate buffer
|
||||
consumed_event = torch.cuda.Event()
|
||||
consumed_event.record()
|
||||
model_state.buffer_consumed_event = consumed_event
|
||||
|
||||
transferred_layer = model_state.layer_to_transfer
|
||||
|
||||
transferred_layer = model_state.layer_to_transfer
|
||||
assert model_state.new_physical_to_logical_map is not None
|
||||
@@ -936,6 +947,13 @@ class EplbState:
|
||||
new_physical_to_logical_map=model_state.new_physical_to_logical_map,
|
||||
layer=transferred_layer,
|
||||
)
|
||||
|
||||
# Record event after consuming buffer to signal async thread
|
||||
# that it's safe to overwrite the intermediate buffer
|
||||
consumed_event = torch.cuda.Event()
|
||||
consumed_event.record()
|
||||
model_state.buffer_consumed_event = consumed_event
|
||||
|
||||
# After the main thread consumes, advance layer_to_transfer
|
||||
model_state.layer_to_transfer += 1
|
||||
model_state.ep_buffer_ready = 0
|
||||
|
||||
@@ -21,6 +21,10 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
|
||||
is_eplb_enabled = parallel_config.enable_eplb
|
||||
async_eplb = parallel_config.eplb_config.use_async
|
||||
is_deepep_ll = parallel_config.all2all_backend == "deepep_low_latency"
|
||||
is_nccl_based_eplb_communicator = parallel_config.eplb_config.communicator in (
|
||||
"torch_nccl",
|
||||
"pynccl",
|
||||
)
|
||||
|
||||
# Override NCCL_MAX_CTAS to avoid hangs when using async EPLB with the
|
||||
# DeepEP low-latency backend.
|
||||
@@ -39,7 +43,13 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
|
||||
# Limiting NCCL occupancy via NCCL_MAX_CTAS leaves space for the DeepEP
|
||||
# cooperative kernel to launch and complete, breaking the deadlock.
|
||||
# See: https://github.com/deepseek-ai/DeepEP/issues/496
|
||||
if is_data_parallel and is_eplb_enabled and is_deepep_ll and async_eplb:
|
||||
if (
|
||||
is_data_parallel
|
||||
and is_eplb_enabled
|
||||
and is_deepep_ll
|
||||
and async_eplb
|
||||
and is_nccl_based_eplb_communicator
|
||||
):
|
||||
current_value_str = os.getenv("NCCL_MAX_CTAS")
|
||||
|
||||
if current_value_str and current_value_str.isdigit():
|
||||
@@ -49,6 +59,7 @@ def override_envs_for_eplb(parallel_config: ParallelConfig) -> None:
|
||||
os.environ["NCCL_MAX_CTAS"] = str(override_value)
|
||||
logger.info_once(
|
||||
f"EPLB: Setting NCCL_MAX_CTAS={override_value} "
|
||||
"for expert parallel with EPLB and deepep_low_latency backend",
|
||||
"for expert parallel with NCCL-based EPLB communicator and "
|
||||
"deepep_low_latency backend",
|
||||
scope="global",
|
||||
)
|
||||
|
||||
@@ -11,19 +11,9 @@ from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributed import (
|
||||
P2POp,
|
||||
ProcessGroup,
|
||||
all_gather,
|
||||
batch_isend_irecv,
|
||||
get_global_rank,
|
||||
)
|
||||
from torch.distributed import ProcessGroup, all_gather
|
||||
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
from .eplb_communicator import EplbCommunicator
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -158,7 +148,8 @@ def move_to_buffer(
|
||||
expert_weights: Sequence[torch.Tensor],
|
||||
expert_weights_buffers: Sequence[torch.Tensor],
|
||||
cuda_stream: torch.cuda.Stream | None,
|
||||
ep_group: ProcessGroup,
|
||||
ep_rank: int,
|
||||
communicator: EplbCommunicator,
|
||||
) -> MoveToBufferResult:
|
||||
"""
|
||||
Rearranges expert weights during EPLB rebalancing.
|
||||
@@ -172,7 +163,8 @@ def move_to_buffer(
|
||||
expert_weights: Original expert weights for the layer.
|
||||
expert_weights_buffers: Intermediate buffers (one per tensor).
|
||||
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||
ep_group: Distributed process group for expert parallel comms.
|
||||
ep_rank: Rank of this process in expert parallel group.
|
||||
communicator: EplbCommunicator instance for P2P communication.
|
||||
|
||||
Returns:
|
||||
is_unchanged (np.ndarray): (num_local_experts,), True where an expert row
|
||||
@@ -182,8 +174,6 @@ def move_to_buffer(
|
||||
RecvMetadata: Metadata needed for completing remote weight transfers.
|
||||
"""
|
||||
assert old_indices.shape == new_indices.shape
|
||||
ep_rank = ep_group.rank()
|
||||
|
||||
recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
|
||||
send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
|
||||
send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32)
|
||||
@@ -247,22 +237,9 @@ def move_to_buffer(
|
||||
expert = new_local_expert_ids[dst]
|
||||
src_local = expert_to_src_map.get(expert, -1)
|
||||
if src_local != -1:
|
||||
for w, b in zip(expert_weights, expert_weights_buffers):
|
||||
b[dst].copy_(w[src_local], non_blocking=True)
|
||||
|
||||
p2p_ops: list[P2POp] = []
|
||||
if isinstance(get_ep_group(), StatelessGroupCoordinator):
|
||||
ep_group = get_ep_group()
|
||||
is_stateless = True
|
||||
else:
|
||||
is_stateless = False
|
||||
|
||||
# Pre-compute global ranks mapping (only needed for non-stateless groups)
|
||||
ep_size = ep_group.size()
|
||||
if not is_stateless:
|
||||
rank_to_global = {
|
||||
rank: get_global_rank(ep_group, rank) for rank in range(ep_size)
|
||||
}
|
||||
with torch.cuda.stream(cuda_stream):
|
||||
for w, b in zip(expert_weights, expert_weights_buffers):
|
||||
b[dst].copy_(w[src_local], non_blocking=True)
|
||||
|
||||
# 2. Post sends
|
||||
if send_count > 0:
|
||||
@@ -294,23 +271,8 @@ def move_to_buffer(
|
||||
if recver_pos < len(ranks_to_recv):
|
||||
recv_ranks.append(ranks_to_recv[recver_pos])
|
||||
for dst in recv_ranks:
|
||||
if is_stateless:
|
||||
for w in expert_weights:
|
||||
op = object.__new__(P2POp)
|
||||
op.op = torch.distributed.isend
|
||||
op.tensor = w[src]
|
||||
op.group_peer = dst
|
||||
p2p_ops.append(op)
|
||||
else:
|
||||
dst_global = rank_to_global[dst]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.isend,
|
||||
w[src],
|
||||
dst_global,
|
||||
)
|
||||
for w in expert_weights
|
||||
]
|
||||
for w in expert_weights:
|
||||
communicator.add_send(w[src], dst)
|
||||
|
||||
# 3. Post recvs
|
||||
if recv_count > 0:
|
||||
@@ -339,40 +301,11 @@ def move_to_buffer(
|
||||
src = ranks_to_send[recver_pos // num_dst_per_sender]
|
||||
else:
|
||||
src = ranks_to_send[recver_pos - remainder_start]
|
||||
if is_stateless:
|
||||
for b in expert_weights_buffers:
|
||||
op = object.__new__(P2POp)
|
||||
op.op = torch.distributed.irecv
|
||||
op.tensor = b[dst]
|
||||
op.group_peer = src
|
||||
p2p_ops.append(op)
|
||||
else:
|
||||
src_global = rank_to_global[src]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.irecv,
|
||||
b[dst],
|
||||
src_global,
|
||||
)
|
||||
for b in expert_weights_buffers
|
||||
]
|
||||
for b in expert_weights_buffers:
|
||||
communicator.add_recv(b[dst], src)
|
||||
|
||||
# 4. Execute the P2P operations. The real communication happens here.
|
||||
if p2p_ops and cuda_stream is not None:
|
||||
with torch.cuda.stream(cuda_stream):
|
||||
if is_stateless:
|
||||
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
|
||||
else:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
elif p2p_ops:
|
||||
if is_stateless:
|
||||
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
|
||||
else:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
communicator.execute()
|
||||
# wait for the communication to finish
|
||||
return (
|
||||
is_unchanged,
|
||||
@@ -471,6 +404,7 @@ async def transfer_layer(
|
||||
expert_weights: Sequence[torch.Tensor],
|
||||
expert_weights_buffer: Sequence[torch.Tensor],
|
||||
ep_group: ProcessGroup,
|
||||
communicator: EplbCommunicator,
|
||||
is_profile: bool = False,
|
||||
cuda_stream: torch.cuda.Stream | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
@@ -489,6 +423,7 @@ async def transfer_layer(
|
||||
For example, a linear layer may have up and down projection.
|
||||
expert_weights_buffer: Intermediate buffers (one per weight tensor).
|
||||
ep_group: The device process group for expert parallelism.
|
||||
communicator: EplbCommunicator instance for P2P communication.
|
||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||
This is used during profile run, where we only perform dummy
|
||||
communications to reserve enough memory for the buffers.
|
||||
@@ -542,7 +477,8 @@ async def transfer_layer(
|
||||
expert_weights=expert_weights,
|
||||
expert_weights_buffers=expert_weights_buffer,
|
||||
cuda_stream=cuda_stream,
|
||||
ep_group=ep_group,
|
||||
ep_rank=ep_group.rank(),
|
||||
communicator=communicator,
|
||||
)
|
||||
return is_unchanged, is_received_locally, recv_metadata
|
||||
|
||||
@@ -552,6 +488,7 @@ def rearrange_expert_weights_inplace(
|
||||
new_global_expert_indices: torch.Tensor,
|
||||
expert_weights: Sequence[Sequence[torch.Tensor]],
|
||||
ep_group: ProcessGroup,
|
||||
communicator: EplbCommunicator,
|
||||
is_profile: bool = False,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> None:
|
||||
@@ -569,6 +506,7 @@ def rearrange_expert_weights_inplace(
|
||||
For example, a linear layer may have up and down projection,
|
||||
so weight_count = 2. Each weight's hidden size can be different.
|
||||
ep_group: The device process group for expert parallelism.
|
||||
communicator: EplbCommunicator instance for P2P communication.
|
||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||
This is used during profile run, where we only perform dummy
|
||||
communications to reserve enough memory for the buffers.
|
||||
@@ -599,6 +537,7 @@ def rearrange_expert_weights_inplace(
|
||||
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
|
||||
|
||||
ep_size = ep_group.size()
|
||||
ep_rank = ep_group.rank()
|
||||
assert num_physical_experts == ep_size * num_local_physical_experts
|
||||
|
||||
first_layer_weights = list(expert_weights[0])
|
||||
@@ -635,7 +574,8 @@ def rearrange_expert_weights_inplace(
|
||||
expert_weights=expert_weights[layer_idx],
|
||||
expert_weights_buffers=weights_buffer,
|
||||
cuda_stream=None,
|
||||
ep_group=ep_group,
|
||||
ep_rank=ep_rank,
|
||||
communicator=communicator,
|
||||
)
|
||||
|
||||
move_from_buffer(
|
||||
@@ -645,7 +585,7 @@ def rearrange_expert_weights_inplace(
|
||||
is_received_locally=is_received_locally,
|
||||
recv_metadata=recv_metadata,
|
||||
new_indices=new_global_expert_indices_cpu[layer_idx],
|
||||
ep_rank=ep_group.rank(),
|
||||
ep_rank=ep_rank,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1690,11 +1690,7 @@ def initialize_model_parallel(
|
||||
# using torch.distributed in execution with torch.distributed in EPLB.
|
||||
global _EPLB
|
||||
assert _EPLB is None, "EPLB group is already initialized"
|
||||
if (
|
||||
config is not None
|
||||
and config.parallel_config is not None
|
||||
and config.parallel_config.enable_eplb
|
||||
):
|
||||
if config.parallel_config.enable_eplb:
|
||||
if enable_elastic_ep:
|
||||
_EPLB = _init_stateless_group(
|
||||
group_ranks,
|
||||
|
||||
Reference in New Issue
Block a user