elastic_ep: Fix issues with repeated scale up/down cycles (#37131)

Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
This commit is contained in:
Itay Alroy
2026-03-21 01:13:02 +02:00
committed by GitHub
parent e5ed6c6c13
commit c57d38d603
10 changed files with 129 additions and 90 deletions

View File

@@ -338,6 +338,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm.destroy()
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None

View File

@@ -145,6 +145,13 @@ class PyNcclCommunicator:
stream.synchronize()
del data
def destroy(self):
if self.available and not self.disabled:
with torch.accelerator.device_index(self.device.index):
self.nccl.ncclCommDestroy(self.comm)
self.available = False
self.disabled = True
def all_reduce(
self,
in_tensor: torch.Tensor,

View File

@@ -145,11 +145,37 @@ class ElasticEPScalingExecutor:
raise ValueError(f"Unknown execute method: {execute_method}")
return method(*args, **kwargs)
def _set_eplb_suppressed(self, suppressed: bool) -> None:
self.worker.model_runner.eep_eplb_suppressed = suppressed
ep_group = get_standby_ep_group() or get_ep_group()
if ep_group.rank == 0:
logger.info(
"[Elastic EP] EPLB %s elastic scaling transition",
"disabled during" if suppressed else "re-enabled after",
)
def load_model(self) -> None:
(
expanded_physical_to_logical,
num_logical_experts,
old_num_physical_experts,
) = self.receive_expert_mapping()
num_physical_experts = expanded_physical_to_logical.shape[1]
self.worker.parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)
self.worker.load_model(load_dummy_weights=True)
self.worker.model_runner.setup_eplb_from_mapping(
expanded_physical_to_logical, old_num_physical_experts
)
self._set_eplb_suppressed(True)
def create_standby_groups(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
self.reconfig_request = reconfig_request
new_dp_size = reconfig_request.new_data_parallel_size
old_dp_size = get_dp_group().world_size
world_size = self.worker.vllm_config.parallel_config.world_size
new_world_size_across_dp = world_size * new_dp_size
updated_config = copy.copy(self.worker.vllm_config)
@@ -165,11 +191,8 @@ class ElasticEPScalingExecutor:
coord_store_port=reconfig_request.coord_store_port,
enable_eplb=updated_config.parallel_config.enable_eplb,
)
self.worker.model_runner.eep_eplb_suppressed = True
standby_ep_group = get_standby_ep_group()
assert standby_ep_group is not None
if standby_ep_group.rank == 0:
logger.info("[Elastic EP] EPLB disabled during elastic scaling transition")
if new_dp_size > old_dp_size:
self._set_eplb_suppressed(True)
def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
standby_dp_group = get_standby_dp_group()
@@ -237,13 +260,31 @@ class ElasticEPScalingExecutor:
device=self.worker.device,
)
def _release_cuda_graphs(self) -> None:
if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
wrapper = self.worker.model_runner.model
wrapper.concrete_cudagraph_entries = {}
elif isinstance(self.worker.model_runner.model, UBatchWrapper):
raise RuntimeError("DBO is not yet supported in elastic EP")
torch.compiler.reset()
with set_current_vllm_config(self.worker.vllm_config):
reset_compile_wrapper(self.worker.model_runner.get_model())
gc.collect()
torch.accelerator.synchronize()
torch.accelerator.empty_cache()
def switch_and_remove(self) -> None:
self._release_cuda_graphs()
_replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)
def switch_and_prepare(self) -> None:
old_dp_size = get_dp_group().world_size
old_ep_size = get_ep_group().world_size
self._release_cuda_graphs()
_replace_active_groups(**pop_standby_groups())
parallel_config = self.worker.vllm_config.parallel_config
@@ -384,13 +425,6 @@ class ElasticEPScalingExecutor:
compilation_counter.stock_torch_compile_count += 1
self.worker.model_runner.model.compile(fullgraph=True, backend=backend)
# release all previously captured CUDA graphs
if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
wrapper = self.worker.model_runner.model
wrapper.concrete_cudagraph_entries = {}
elif isinstance(self.worker.model_runner.model, UBatchWrapper):
raise RuntimeError("DBO is not yet supported in elastic EP")
multi_block_table = self.worker.model_runner.input_batch.block_table
saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
for bt in multi_block_table.block_tables:
@@ -399,14 +433,6 @@ class ElasticEPScalingExecutor:
)
multi_block_table.clear()
# reset the compile wrapper
torch.compiler.reset()
with set_current_vllm_config(self.worker.vllm_config):
reset_compile_wrapper(self.worker.model_runner.get_model())
gc.collect()
torch.accelerator.synchronize()
torch.accelerator.empty_cache()
unlock_workspace()
self.worker.compile_or_warm_up_model()
lock_workspace()
@@ -416,8 +442,12 @@ class ElasticEPScalingExecutor:
):
bt.block_table.gpu.copy_(saved_gpu)
bt.block_table.cpu.copy_(saved_cpu)
if new_dp_size < old_dp_size:
self._set_eplb_suppressed(False)
def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
def _perform_eplb_reshuffle(
self, rank_mapping: dict[int, int] | None = None
) -> None:
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Starting expert resharding...")
@@ -428,20 +458,9 @@ class ElasticEPScalingExecutor:
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
is_async_enabled = eplb_state.is_async
eplb_state.is_async = False
if new_dp_size is None:
if rank_mapping is None:
eplb_state.rearrange()
else:
# scale down
parallel_config = self.worker.vllm_config.parallel_config
tp_size = parallel_config.tensor_parallel_size
old_ep_size = parallel_config.data_parallel_size * tp_size
new_ep_size = new_dp_size * tp_size
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
eplb_state.rearrange(rank_mapping=rank_mapping)
# NOTE(yongji): check whether we need to synchronize here
torch.accelerator.synchronize()
@@ -451,10 +470,25 @@ class ElasticEPScalingExecutor:
eplb_model_state.physical_to_logical_map.shape[1]
)
eplb_state.is_async = is_async_enabled
self.worker.model_runner.eep_eplb_suppressed = False
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed")
def perform_eplb_reshuffle(self) -> None:
self._perform_eplb_reshuffle()
self._set_eplb_suppressed(False)
def perform_scale_down_eplb_reshuffle(self, new_dp_size: int) -> None:
self._set_eplb_suppressed(True)
parallel_config = self.worker.vllm_config.parallel_config
tp_size = parallel_config.tensor_parallel_size
old_ep_size = parallel_config.data_parallel_size * tp_size
new_ep_size = new_dp_size * tp_size
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
self._perform_eplb_reshuffle(rank_mapping=rank_mapping)
def receive_weights(self) -> None:
dp_group = get_dp_group()
assert isinstance(dp_group, StatelessGroupCoordinator)

View File

@@ -43,9 +43,10 @@ class ScaleUpExistingEngineState(enum.IntEnum):
class ScaleUpNewEngineState(enum.IntEnum):
PREPARE = 0
EPLB_RESHUFFLE = 1
COMPLETE = 2
PRE_KV_INIT = 0
PREPARE = 1
EPLB_RESHUFFLE = 2
COMPLETE = 3
class ScaleDownRemainingEngineState(enum.IntEnum):
@@ -104,7 +105,7 @@ class ElasticEPScalingState:
self.state: EngineState
if scale_type == "scale_up":
self.state = (
ScaleUpNewEngineState.PREPARE
ScaleUpNewEngineState.PRE_KV_INIT
if worker_type == "new"
else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
)
@@ -142,6 +143,12 @@ class ElasticEPScalingState:
else self._progress_remaining_engine()
)
def run_pre_kv_init_states(self) -> None:
assert self.scale_type == "scale_up" and self.worker_type == "new"
assert self.state == ScaleUpNewEngineState.PRE_KV_INIT
assert self.progress()
assert self.state == ScaleUpNewEngineState.PREPARE
def _execute_tcp_store_barrier(
self, dp_store, group_rank, group_size, barrier_id, timeout=None
):
@@ -303,7 +310,23 @@ class ElasticEPScalingState:
state = self.state
assert self.new_dp_group is not None and self.new_dp_store is not None
if state == ScaleUpNewEngineState.PREPARE:
if state == ScaleUpNewEngineState.PRE_KV_INIT:
self.engine_core._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("receive_weights",)
)
self.engine_core.available_gpu_memory_for_kv_cache = (
ParallelConfig.sync_kv_cache_memory_size(self.new_dp_group, -1)
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("prepare_new_worker",)
)
self.state = ScaleUpNewEngineState.PREPARE
return True
elif state == ScaleUpNewEngineState.PREPARE:
tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
torch.distributed.all_reduce(
tensor,
@@ -403,7 +426,6 @@ class ElasticEPScalingState:
self.engine_core._eep_send_engine_core_notification(
EEPNotificationType.SHUTDOWN_COMPLETE
)
self.engine_core.shutdown()
return True
else:
@@ -525,7 +547,7 @@ class ElasticEPScalingState:
self.model_executor.collective_rpc(
"elastic_ep_execute",
args=(
"perform_eplb_reshuffle",
"perform_scale_down_eplb_reshuffle",
self.reconfig_request.new_data_parallel_size,
),
)

View File

@@ -1694,6 +1694,8 @@ class DPEngineCoreProc(EngineCoreProc):
if self.eep_scaling_state is not None:
_ = self.eep_scaling_state.progress()
if self.eep_scaling_state.is_complete():
if self.eep_scaling_state.worker_type == "removing":
raise SystemExit
self.process_input_queue_block = True
self.eep_scaling_state = None
@@ -1857,20 +1859,7 @@ class DPEngineCoreProc(EngineCoreProc):
scale_type="scale_up",
reconfig_request=None,
)
self.model_executor.collective_rpc("init_device")
self.model_executor.collective_rpc("load_model")
self._eep_send_engine_core_notification(
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("receive_weights",)
)
self.available_gpu_memory_for_kv_cache = (
ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1)
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("prepare_new_worker",)
)
self.eep_scaling_state.run_pre_kv_init_states()
self.process_input_queue_block = False

View File

@@ -602,13 +602,14 @@ class WorkerProc:
)
# Load model
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.worker.init_device()
# Update process title now that parallel groups are initialized
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
self.worker.init_device()
# Update process title now that parallel groups are initialized
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self.worker.elastic_ep_execute("load_model")
else:
self.worker.load_model()
scheduler_config = vllm_config.scheduler_config

View File

@@ -382,9 +382,10 @@ class RayDistributedExecutor(Executor):
all_kwargs.append(kwargs)
self.collective_rpc("init_worker", args=(all_kwargs,))
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
if not is_eep_new_worker:
self.collective_rpc("init_device")
self.collective_rpc("init_device")
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self.collective_rpc("elastic_ep_execute", args=("load_model",))
else:
self.collective_rpc("load_model")
def _update_block_size(worker):

View File

@@ -43,12 +43,14 @@ class UniProcExecutor(Executor):
max_workers=1, thread_name_prefix="WorkerAsyncOutput"
)
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
self.driver_worker.init_worker(all_kwargs=[kwargs])
if not is_eep_new_worker:
self.driver_worker.init_device()
self.driver_worker.init_device()
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
self.driver_worker.elastic_ep_execute("load_model")
else:
self.driver_worker.load_model()
current_platform.update_block_size_for_backend(self.vllm_config)
current_platform.update_block_size_for_backend(self.vllm_config)
def _distributed_args(self) -> tuple[str, int, int]:
"""Return (distributed_init_method, rank, local_rank)."""

View File

@@ -315,30 +315,12 @@ class Worker(WorkerBase):
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
def load_model(self) -> None:
dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
if dummy_weights:
(
expanded_physical_to_logical,
num_logical_experts,
old_num_physical_experts,
) = self.elastic_ep_executor.receive_expert_mapping()
num_physical_experts = expanded_physical_to_logical.shape[1]
self.parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)
def load_model(self, *, load_dummy_weights: bool = False) -> None:
with (
self._maybe_get_memory_pool_context(tag="weights"),
set_current_vllm_config(self.vllm_config),
):
self.model_runner.load_model(load_dummy_weights=dummy_weights)
if dummy_weights:
self.model_runner.setup_eplb_from_mapping(
expanded_physical_to_logical, old_num_physical_experts
)
self.model_runner.eep_eplb_suppressed = True
self.model_runner.load_model(load_dummy_weights=load_dummy_weights)
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)

View File

@@ -122,7 +122,7 @@ class WorkerBase:
return format_model_inspection(self.get_model())
def load_model(self) -> None:
def load_model(self, *, load_dummy_weights: bool = False) -> None:
"""Load model onto target device."""
raise NotImplementedError