[1/N] Elastic EP Milestone 2 (#34861)
Signed-off-by: Yongji Wu <wuyongji317@gmail.com> Signed-off-by: Itay Alroy <ialroy@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Ron Tourgeman <rtourgeman@nvidia.com> Co-authored-by: Yongji Wu <wuyongji317@gmail.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
This commit is contained in:
@@ -7,11 +7,10 @@ import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from types import NoneType
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -32,14 +31,12 @@ from vllm.distributed.kv_transfer import (
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
Handle,
|
||||
get_pcp_group,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.models.interfaces import is_mixture_of_experts
|
||||
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
|
||||
@@ -49,7 +46,6 @@ from vllm.tracing import instrument
|
||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import (
|
||||
AsyncModelRunnerOutput,
|
||||
@@ -124,6 +120,10 @@ class Worker(WorkerBase):
|
||||
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
|
||||
torch.set_float32_matmul_precision(precision)
|
||||
|
||||
from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor
|
||||
|
||||
self.elastic_ep_executor = ElasticEPScalingExecutor(self)
|
||||
|
||||
# Buffers saved before sleep
|
||||
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
|
||||
|
||||
@@ -317,12 +317,29 @@ class Worker(WorkerBase):
|
||||
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
|
||||
# to hijack tensor allocation.
|
||||
def load_model(self) -> None:
|
||||
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
|
||||
dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
|
||||
if dummy_weights:
|
||||
(
|
||||
expanded_physical_to_logical,
|
||||
num_logical_experts,
|
||||
old_num_physical_experts,
|
||||
) = self.elastic_ep_executor.receive_expert_mapping()
|
||||
num_physical_experts = expanded_physical_to_logical.shape[1]
|
||||
self.parallel_config.eplb_config.num_redundant_experts = (
|
||||
num_physical_experts - num_logical_experts
|
||||
)
|
||||
|
||||
with (
|
||||
self._maybe_get_memory_pool_context(tag="weights"),
|
||||
set_current_vllm_config(self.vllm_config),
|
||||
):
|
||||
self.model_runner.load_model(eep_scale_up=eep_scale_up)
|
||||
self.model_runner.load_model(load_dummy_weights=dummy_weights)
|
||||
|
||||
if dummy_weights:
|
||||
self.model_runner.setup_eplb_from_mapping(
|
||||
expanded_physical_to_logical, old_num_physical_experts
|
||||
)
|
||||
self.model_runner.eep_eplb_suppressed = True
|
||||
|
||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||
self.model_runner.update_config(overrides)
|
||||
@@ -801,227 +818,6 @@ class Worker(WorkerBase):
|
||||
# worker will always be healthy as long as it's running.
|
||||
return
|
||||
|
||||
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info(
|
||||
"[Elastic EP] Starting expert resharding before scaling down..."
|
||||
)
|
||||
rank_mapping = {
|
||||
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
|
||||
for old_ep_rank in range(old_ep_size)
|
||||
}
|
||||
assert self.model_runner.eplb_state is not None
|
||||
self.model_runner.eplb_state.rearrange(
|
||||
execute_shuffle=True,
|
||||
global_expert_loads=None,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed!")
|
||||
|
||||
def _eplb_after_scale_up(
|
||||
self,
|
||||
old_ep_size: int,
|
||||
new_ep_size: int,
|
||||
global_expert_loads: list[torch.Tensor] | None,
|
||||
) -> None:
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Starting expert resharding after scaling up...")
|
||||
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
|
||||
assert self.model_runner.eplb_state is not None
|
||||
self.model_runner.eplb_state.rearrange(
|
||||
execute_shuffle=True,
|
||||
global_expert_loads=global_expert_loads,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed!")
|
||||
|
||||
def _reconfigure_parallel_config(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
"""
|
||||
Update parallel config with provided reconfig_request
|
||||
"""
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank_local = (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
)
|
||||
parallel_config.data_parallel_master_ip = (
|
||||
reconfig_request.new_data_parallel_master_ip
|
||||
)
|
||||
parallel_config.data_parallel_master_port = (
|
||||
reconfig_request.new_data_parallel_master_port
|
||||
)
|
||||
|
||||
def _reconfigure_moe(
|
||||
self, old_ep_size: int, new_ep_size: int
|
||||
) -> list[torch.Tensor] | None:
|
||||
"""
|
||||
Reconfigure MoE modules with provided reconfig_request
|
||||
|
||||
Return the global expert load if new_ep_size > old_ep_size,
|
||||
otherwise None
|
||||
"""
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
prepare_communication_buffer_for_model,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE,
|
||||
FusedMoEParallelConfig,
|
||||
)
|
||||
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
|
||||
def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
|
||||
return [
|
||||
module
|
||||
for module in model.modules()
|
||||
if (
|
||||
module.__class__.__name__ == "FusedMoE"
|
||||
or module.__class__.__name__ == "SharedFusedMoE"
|
||||
)
|
||||
]
|
||||
|
||||
def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
|
||||
assert all(
|
||||
module.moe_config.num_local_experts == num_local_experts
|
||||
for module in moe_modules
|
||||
), "All MoE modules must have the same number of experts"
|
||||
for module in moe_modules:
|
||||
module.moe_config.num_experts = num_local_experts * new_ep_size
|
||||
module.global_num_experts = module.moe_config.num_experts
|
||||
tp_size = get_tp_group().world_size
|
||||
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
sp_size = tp_size if is_sequence_parallel else 1
|
||||
module.moe_parallel_config = FusedMoEParallelConfig.make(
|
||||
tp_size_=tp_size,
|
||||
pcp_size_=get_pcp_group().world_size,
|
||||
dp_size_=get_dp_group().world_size,
|
||||
sp_size_=sp_size,
|
||||
vllm_parallel_config=parallel_config,
|
||||
)
|
||||
module.moe_config.moe_parallel_config = module.moe_parallel_config
|
||||
return moe_modules
|
||||
|
||||
model_moe_modules = get_moe_modules(self.model_runner.model)
|
||||
num_local_experts = model_moe_modules[0].moe_config.num_local_experts
|
||||
|
||||
update_moe_modules(model_moe_modules, num_local_experts)
|
||||
drafter_model = None
|
||||
if hasattr(self.model_runner, "drafter") and hasattr(
|
||||
self.model_runner.drafter, "model"
|
||||
):
|
||||
drafter_model = self.model_runner.drafter.model
|
||||
if drafter_model is not None and is_mixture_of_experts(drafter_model):
|
||||
drafter_moe_modules = get_moe_modules(drafter_model)
|
||||
# Check if drafter and model have matching configs
|
||||
assert (
|
||||
drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
|
||||
), "Drafter and model configs should be the same"
|
||||
update_moe_modules(drafter_moe_modules, num_local_experts)
|
||||
|
||||
if new_ep_size < old_ep_size:
|
||||
num_local_physical_experts = num_local_experts
|
||||
assert self.model_runner.eplb_state is not None
|
||||
new_physical_experts = (
|
||||
self.model_runner.eplb_state.physical_to_logical_map.shape[1] # type: ignore[attr-defined]
|
||||
)
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
new_physical_experts
|
||||
- self.model_runner.eplb_state.logical_replica_count.shape[1] # type: ignore[attr-defined]
|
||||
)
|
||||
global_expert_loads = None
|
||||
else:
|
||||
num_local_physical_experts_tensor = torch.tensor(
|
||||
[num_local_experts], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
num_local_physical_experts_tensor,
|
||||
group=get_ep_group().cpu_group,
|
||||
group_src=0,
|
||||
)
|
||||
num_local_physical_experts = int(num_local_physical_experts_tensor.item())
|
||||
new_physical_experts = num_local_physical_experts * new_ep_size
|
||||
assert self.model_runner.eplb_state is not None
|
||||
global_expert_loads_any = self.model_runner.eplb_state.rearrange(
|
||||
execute_shuffle=False
|
||||
)
|
||||
global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
new_physical_experts - global_expert_loads[0].shape[1]
|
||||
)
|
||||
prepare_communication_buffer_for_model(self.model_runner.model)
|
||||
if drafter_model is not None:
|
||||
prepare_communication_buffer_for_model(drafter_model)
|
||||
self.model_runner.model.update_physical_experts_metadata(
|
||||
num_physical_experts=new_physical_experts,
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
)
|
||||
return global_expert_loads
|
||||
|
||||
def reinitialize_distributed(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (
|
||||
cleanup_dist_env_and_memory,
|
||||
get_ep_group,
|
||||
)
|
||||
|
||||
old_ep_size = get_ep_group().world_size
|
||||
old_ep_rank = get_ep_group().rank
|
||||
new_ep_size = (
|
||||
reconfig_request.new_data_parallel_size
|
||||
* get_tp_group().world_size
|
||||
* get_pp_group().world_size
|
||||
)
|
||||
if new_ep_size < old_ep_size:
|
||||
self._eplb_before_scale_down(old_ep_size, new_ep_size)
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
|
||||
):
|
||||
assert old_ep_rank >= new_ep_size
|
||||
# shutdown
|
||||
return
|
||||
|
||||
self._reconfigure_parallel_config(reconfig_request)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
init_worker_distributed_environment(
|
||||
self.vllm_config,
|
||||
self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank,
|
||||
)
|
||||
|
||||
global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
|
||||
|
||||
if new_ep_size > old_ep_size:
|
||||
assert global_expert_loads is not None
|
||||
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
@@ -1118,6 +914,9 @@ class Worker(WorkerBase):
|
||||
if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
|
||||
weight_transfer_engine.shutdown()
|
||||
|
||||
def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
|
||||
return self.elastic_ep_executor.execute(execute_method, *args, **kwargs)
|
||||
|
||||
|
||||
def init_worker_distributed_environment(
|
||||
vllm_config: VllmConfig,
|
||||
|
||||
Reference in New Issue
Block a user