elastic_ep: Fix issues with repeated scale up/down cycles (#37131)
Signed-off-by: Itay Alroy <ialroy@nvidia.com> Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
This commit is contained in:
@@ -338,6 +338,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def destroy(self):
|
||||
if self.pynccl_comm is not None:
|
||||
self.pynccl_comm.destroy()
|
||||
self.pynccl_comm = None
|
||||
if self.ca_comm is not None:
|
||||
self.ca_comm = None
|
||||
|
||||
@@ -145,6 +145,13 @@ class PyNcclCommunicator:
|
||||
stream.synchronize()
|
||||
del data
|
||||
|
||||
def destroy(self):
|
||||
if self.available and not self.disabled:
|
||||
with torch.accelerator.device_index(self.device.index):
|
||||
self.nccl.ncclCommDestroy(self.comm)
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
|
||||
def all_reduce(
|
||||
self,
|
||||
in_tensor: torch.Tensor,
|
||||
|
||||
@@ -145,11 +145,37 @@ class ElasticEPScalingExecutor:
|
||||
raise ValueError(f"Unknown execute method: {execute_method}")
|
||||
return method(*args, **kwargs)
|
||||
|
||||
def _set_eplb_suppressed(self, suppressed: bool) -> None:
|
||||
self.worker.model_runner.eep_eplb_suppressed = suppressed
|
||||
ep_group = get_standby_ep_group() or get_ep_group()
|
||||
if ep_group.rank == 0:
|
||||
logger.info(
|
||||
"[Elastic EP] EPLB %s elastic scaling transition",
|
||||
"disabled during" if suppressed else "re-enabled after",
|
||||
)
|
||||
|
||||
def load_model(self) -> None:
|
||||
(
|
||||
expanded_physical_to_logical,
|
||||
num_logical_experts,
|
||||
old_num_physical_experts,
|
||||
) = self.receive_expert_mapping()
|
||||
num_physical_experts = expanded_physical_to_logical.shape[1]
|
||||
self.worker.parallel_config.eplb_config.num_redundant_experts = (
|
||||
num_physical_experts - num_logical_experts
|
||||
)
|
||||
self.worker.load_model(load_dummy_weights=True)
|
||||
self.worker.model_runner.setup_eplb_from_mapping(
|
||||
expanded_physical_to_logical, old_num_physical_experts
|
||||
)
|
||||
self._set_eplb_suppressed(True)
|
||||
|
||||
def create_standby_groups(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
self.reconfig_request = reconfig_request
|
||||
new_dp_size = reconfig_request.new_data_parallel_size
|
||||
old_dp_size = get_dp_group().world_size
|
||||
world_size = self.worker.vllm_config.parallel_config.world_size
|
||||
new_world_size_across_dp = world_size * new_dp_size
|
||||
updated_config = copy.copy(self.worker.vllm_config)
|
||||
@@ -165,11 +191,8 @@ class ElasticEPScalingExecutor:
|
||||
coord_store_port=reconfig_request.coord_store_port,
|
||||
enable_eplb=updated_config.parallel_config.enable_eplb,
|
||||
)
|
||||
self.worker.model_runner.eep_eplb_suppressed = True
|
||||
standby_ep_group = get_standby_ep_group()
|
||||
assert standby_ep_group is not None
|
||||
if standby_ep_group.rank == 0:
|
||||
logger.info("[Elastic EP] EPLB disabled during elastic scaling transition")
|
||||
if new_dp_size > old_dp_size:
|
||||
self._set_eplb_suppressed(True)
|
||||
|
||||
def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
|
||||
standby_dp_group = get_standby_dp_group()
|
||||
@@ -237,13 +260,31 @@ class ElasticEPScalingExecutor:
|
||||
device=self.worker.device,
|
||||
)
|
||||
|
||||
def _release_cuda_graphs(self) -> None:
|
||||
if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
|
||||
wrapper = self.worker.model_runner.model
|
||||
wrapper.concrete_cudagraph_entries = {}
|
||||
|
||||
elif isinstance(self.worker.model_runner.model, UBatchWrapper):
|
||||
raise RuntimeError("DBO is not yet supported in elastic EP")
|
||||
|
||||
torch.compiler.reset()
|
||||
with set_current_vllm_config(self.worker.vllm_config):
|
||||
reset_compile_wrapper(self.worker.model_runner.get_model())
|
||||
|
||||
gc.collect()
|
||||
torch.accelerator.synchronize()
|
||||
torch.accelerator.empty_cache()
|
||||
|
||||
def switch_and_remove(self) -> None:
|
||||
self._release_cuda_graphs()
|
||||
_replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)
|
||||
|
||||
def switch_and_prepare(self) -> None:
|
||||
old_dp_size = get_dp_group().world_size
|
||||
old_ep_size = get_ep_group().world_size
|
||||
|
||||
self._release_cuda_graphs()
|
||||
_replace_active_groups(**pop_standby_groups())
|
||||
|
||||
parallel_config = self.worker.vllm_config.parallel_config
|
||||
@@ -384,13 +425,6 @@ class ElasticEPScalingExecutor:
|
||||
compilation_counter.stock_torch_compile_count += 1
|
||||
self.worker.model_runner.model.compile(fullgraph=True, backend=backend)
|
||||
|
||||
# release all previously captured CUDA graphs
|
||||
if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
|
||||
wrapper = self.worker.model_runner.model
|
||||
wrapper.concrete_cudagraph_entries = {}
|
||||
elif isinstance(self.worker.model_runner.model, UBatchWrapper):
|
||||
raise RuntimeError("DBO is not yet supported in elastic EP")
|
||||
|
||||
multi_block_table = self.worker.model_runner.input_batch.block_table
|
||||
saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
|
||||
for bt in multi_block_table.block_tables:
|
||||
@@ -399,14 +433,6 @@ class ElasticEPScalingExecutor:
|
||||
)
|
||||
multi_block_table.clear()
|
||||
|
||||
# reset the compile wrapper
|
||||
torch.compiler.reset()
|
||||
with set_current_vllm_config(self.worker.vllm_config):
|
||||
reset_compile_wrapper(self.worker.model_runner.get_model())
|
||||
|
||||
gc.collect()
|
||||
torch.accelerator.synchronize()
|
||||
torch.accelerator.empty_cache()
|
||||
unlock_workspace()
|
||||
self.worker.compile_or_warm_up_model()
|
||||
lock_workspace()
|
||||
@@ -416,8 +442,12 @@ class ElasticEPScalingExecutor:
|
||||
):
|
||||
bt.block_table.gpu.copy_(saved_gpu)
|
||||
bt.block_table.cpu.copy_(saved_cpu)
|
||||
if new_dp_size < old_dp_size:
|
||||
self._set_eplb_suppressed(False)
|
||||
|
||||
def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
|
||||
def _perform_eplb_reshuffle(
|
||||
self, rank_mapping: dict[int, int] | None = None
|
||||
) -> None:
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Starting expert resharding...")
|
||||
|
||||
@@ -428,20 +458,9 @@ class ElasticEPScalingExecutor:
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
is_async_enabled = eplb_state.is_async
|
||||
eplb_state.is_async = False
|
||||
if new_dp_size is None:
|
||||
if rank_mapping is None:
|
||||
eplb_state.rearrange()
|
||||
else:
|
||||
# scale down
|
||||
parallel_config = self.worker.vllm_config.parallel_config
|
||||
tp_size = parallel_config.tensor_parallel_size
|
||||
old_ep_size = parallel_config.data_parallel_size * tp_size
|
||||
new_ep_size = new_dp_size * tp_size
|
||||
|
||||
rank_mapping = {
|
||||
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
|
||||
for old_ep_rank in range(old_ep_size)
|
||||
}
|
||||
|
||||
eplb_state.rearrange(rank_mapping=rank_mapping)
|
||||
# NOTE(yongji): check whether we need to synchronize here
|
||||
torch.accelerator.synchronize()
|
||||
@@ -451,10 +470,25 @@ class ElasticEPScalingExecutor:
|
||||
eplb_model_state.physical_to_logical_map.shape[1]
|
||||
)
|
||||
eplb_state.is_async = is_async_enabled
|
||||
self.worker.model_runner.eep_eplb_suppressed = False
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed")
|
||||
|
||||
def perform_eplb_reshuffle(self) -> None:
|
||||
self._perform_eplb_reshuffle()
|
||||
self._set_eplb_suppressed(False)
|
||||
|
||||
def perform_scale_down_eplb_reshuffle(self, new_dp_size: int) -> None:
|
||||
self._set_eplb_suppressed(True)
|
||||
parallel_config = self.worker.vllm_config.parallel_config
|
||||
tp_size = parallel_config.tensor_parallel_size
|
||||
old_ep_size = parallel_config.data_parallel_size * tp_size
|
||||
new_ep_size = new_dp_size * tp_size
|
||||
rank_mapping = {
|
||||
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
|
||||
for old_ep_rank in range(old_ep_size)
|
||||
}
|
||||
self._perform_eplb_reshuffle(rank_mapping=rank_mapping)
|
||||
|
||||
def receive_weights(self) -> None:
|
||||
dp_group = get_dp_group()
|
||||
assert isinstance(dp_group, StatelessGroupCoordinator)
|
||||
|
||||
@@ -43,9 +43,10 @@ class ScaleUpExistingEngineState(enum.IntEnum):
|
||||
|
||||
|
||||
class ScaleUpNewEngineState(enum.IntEnum):
|
||||
PREPARE = 0
|
||||
EPLB_RESHUFFLE = 1
|
||||
COMPLETE = 2
|
||||
PRE_KV_INIT = 0
|
||||
PREPARE = 1
|
||||
EPLB_RESHUFFLE = 2
|
||||
COMPLETE = 3
|
||||
|
||||
|
||||
class ScaleDownRemainingEngineState(enum.IntEnum):
|
||||
@@ -104,7 +105,7 @@ class ElasticEPScalingState:
|
||||
self.state: EngineState
|
||||
if scale_type == "scale_up":
|
||||
self.state = (
|
||||
ScaleUpNewEngineState.PREPARE
|
||||
ScaleUpNewEngineState.PRE_KV_INIT
|
||||
if worker_type == "new"
|
||||
else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
|
||||
)
|
||||
@@ -142,6 +143,12 @@ class ElasticEPScalingState:
|
||||
else self._progress_remaining_engine()
|
||||
)
|
||||
|
||||
def run_pre_kv_init_states(self) -> None:
|
||||
assert self.scale_type == "scale_up" and self.worker_type == "new"
|
||||
assert self.state == ScaleUpNewEngineState.PRE_KV_INIT
|
||||
assert self.progress()
|
||||
assert self.state == ScaleUpNewEngineState.PREPARE
|
||||
|
||||
def _execute_tcp_store_barrier(
|
||||
self, dp_store, group_rank, group_size, barrier_id, timeout=None
|
||||
):
|
||||
@@ -303,7 +310,23 @@ class ElasticEPScalingState:
|
||||
state = self.state
|
||||
assert self.new_dp_group is not None and self.new_dp_store is not None
|
||||
|
||||
if state == ScaleUpNewEngineState.PREPARE:
|
||||
if state == ScaleUpNewEngineState.PRE_KV_INIT:
|
||||
self.engine_core._eep_send_engine_core_notification(
|
||||
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
|
||||
)
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("receive_weights",)
|
||||
)
|
||||
self.engine_core.available_gpu_memory_for_kv_cache = (
|
||||
ParallelConfig.sync_kv_cache_memory_size(self.new_dp_group, -1)
|
||||
)
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("prepare_new_worker",)
|
||||
)
|
||||
self.state = ScaleUpNewEngineState.PREPARE
|
||||
return True
|
||||
|
||||
elif state == ScaleUpNewEngineState.PREPARE:
|
||||
tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
|
||||
torch.distributed.all_reduce(
|
||||
tensor,
|
||||
@@ -403,7 +426,6 @@ class ElasticEPScalingState:
|
||||
self.engine_core._eep_send_engine_core_notification(
|
||||
EEPNotificationType.SHUTDOWN_COMPLETE
|
||||
)
|
||||
self.engine_core.shutdown()
|
||||
return True
|
||||
|
||||
else:
|
||||
@@ -525,7 +547,7 @@ class ElasticEPScalingState:
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute",
|
||||
args=(
|
||||
"perform_eplb_reshuffle",
|
||||
"perform_scale_down_eplb_reshuffle",
|
||||
self.reconfig_request.new_data_parallel_size,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1694,6 +1694,8 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
if self.eep_scaling_state is not None:
|
||||
_ = self.eep_scaling_state.progress()
|
||||
if self.eep_scaling_state.is_complete():
|
||||
if self.eep_scaling_state.worker_type == "removing":
|
||||
raise SystemExit
|
||||
self.process_input_queue_block = True
|
||||
self.eep_scaling_state = None
|
||||
|
||||
@@ -1857,20 +1859,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
scale_type="scale_up",
|
||||
reconfig_request=None,
|
||||
)
|
||||
self.model_executor.collective_rpc("init_device")
|
||||
self.model_executor.collective_rpc("load_model")
|
||||
self._eep_send_engine_core_notification(
|
||||
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
|
||||
)
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("receive_weights",)
|
||||
)
|
||||
self.available_gpu_memory_for_kv_cache = (
|
||||
ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1)
|
||||
)
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("prepare_new_worker",)
|
||||
)
|
||||
self.eep_scaling_state.run_pre_kv_init_states()
|
||||
self.process_input_queue_block = False
|
||||
|
||||
|
||||
|
||||
@@ -602,13 +602,14 @@ class WorkerProc:
|
||||
)
|
||||
|
||||
# Load model
|
||||
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
|
||||
if not is_eep_new_worker:
|
||||
self.worker.init_device()
|
||||
# Update process title now that parallel groups are initialized
|
||||
self.setup_proc_title_and_log_prefix(
|
||||
enable_ep=vllm_config.parallel_config.enable_expert_parallel
|
||||
)
|
||||
self.worker.init_device()
|
||||
# Update process title now that parallel groups are initialized
|
||||
self.setup_proc_title_and_log_prefix(
|
||||
enable_ep=vllm_config.parallel_config.enable_expert_parallel
|
||||
)
|
||||
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
|
||||
self.worker.elastic_ep_execute("load_model")
|
||||
else:
|
||||
self.worker.load_model()
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
|
||||
@@ -382,9 +382,10 @@ class RayDistributedExecutor(Executor):
|
||||
all_kwargs.append(kwargs)
|
||||
self.collective_rpc("init_worker", args=(all_kwargs,))
|
||||
|
||||
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
|
||||
if not is_eep_new_worker:
|
||||
self.collective_rpc("init_device")
|
||||
self.collective_rpc("init_device")
|
||||
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
|
||||
self.collective_rpc("elastic_ep_execute", args=("load_model",))
|
||||
else:
|
||||
self.collective_rpc("load_model")
|
||||
|
||||
def _update_block_size(worker):
|
||||
|
||||
@@ -43,12 +43,14 @@ class UniProcExecutor(Executor):
|
||||
max_workers=1, thread_name_prefix="WorkerAsyncOutput"
|
||||
)
|
||||
|
||||
is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
|
||||
self.driver_worker.init_worker(all_kwargs=[kwargs])
|
||||
if not is_eep_new_worker:
|
||||
self.driver_worker.init_device()
|
||||
self.driver_worker.init_device()
|
||||
|
||||
if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
|
||||
self.driver_worker.elastic_ep_execute("load_model")
|
||||
else:
|
||||
self.driver_worker.load_model()
|
||||
current_platform.update_block_size_for_backend(self.vllm_config)
|
||||
current_platform.update_block_size_for_backend(self.vllm_config)
|
||||
|
||||
def _distributed_args(self) -> tuple[str, int, int]:
|
||||
"""Return (distributed_init_method, rank, local_rank)."""
|
||||
|
||||
@@ -315,30 +315,12 @@ class Worker(WorkerBase):
|
||||
|
||||
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
|
||||
# to hijack tensor allocation.
|
||||
def load_model(self) -> None:
|
||||
dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
|
||||
if dummy_weights:
|
||||
(
|
||||
expanded_physical_to_logical,
|
||||
num_logical_experts,
|
||||
old_num_physical_experts,
|
||||
) = self.elastic_ep_executor.receive_expert_mapping()
|
||||
num_physical_experts = expanded_physical_to_logical.shape[1]
|
||||
self.parallel_config.eplb_config.num_redundant_experts = (
|
||||
num_physical_experts - num_logical_experts
|
||||
)
|
||||
|
||||
def load_model(self, *, load_dummy_weights: bool = False) -> None:
|
||||
with (
|
||||
self._maybe_get_memory_pool_context(tag="weights"),
|
||||
set_current_vllm_config(self.vllm_config),
|
||||
):
|
||||
self.model_runner.load_model(load_dummy_weights=dummy_weights)
|
||||
|
||||
if dummy_weights:
|
||||
self.model_runner.setup_eplb_from_mapping(
|
||||
expanded_physical_to_logical, old_num_physical_experts
|
||||
)
|
||||
self.model_runner.eep_eplb_suppressed = True
|
||||
self.model_runner.load_model(load_dummy_weights=load_dummy_weights)
|
||||
|
||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||
self.model_runner.update_config(overrides)
|
||||
|
||||
@@ -122,7 +122,7 @@ class WorkerBase:
|
||||
|
||||
return format_model_inspection(self.get_model())
|
||||
|
||||
def load_model(self) -> None:
|
||||
def load_model(self, *, load_dummy_weights: bool = False) -> None:
|
||||
"""Load model onto target device."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
Reference in New Issue
Block a user