Elastic Expert Parallel Initial Support (#20775)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Author: Rui Qiao
Date: 2025-07-18 17:46:09 -07:00
Committed by: GitHub
Parent: 5782581acf
Commit: 217937221b

24 changed files with 1659 additions and 68 deletions


```diff
@@ -32,7 +32,9 @@ from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType, UtilityOutput)
+                            EngineCoreRequestType,
+                            ReconfigureDistributedRequest, ReconfigureRankType,
+                            UtilityOutput)
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
 from vllm.v1.executor.abstract import Executor
```
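`ReconfigureDistributedRequest` and `ReconfigureRankType` are new in this commit and their definitions are not part of this excerpt, but the hunks below pin down the fields and sentinel values they must provide. A minimal sketch consistent with that usage — field set and enum values inferred from the diff, not copied from the real `vllm.v1.engine` definitions:

```python
import enum
from dataclasses import dataclass


class ReconfigureRankType(enum.IntEnum):
    # Sentinels inferred from the `!= -1` / `!= -2` checks in
    # reinitialize_distributed() below; the real enum may differ.
    KEEP_CURRENT_RANK = -1
    SHUTDOWN_CURRENT_RANK = -2


@dataclass
class ReconfigureDistributedRequest:
    # Only the fields this diff touches; the real class may carry more.
    new_data_parallel_size: int
    new_data_parallel_rank: int        # a real rank or a ReconfigureRankType
    new_data_parallel_rank_local: int  # must stay KEEP_CURRENT_RANK (see assert)
    new_data_parallel_master_ip: str
    new_data_parallel_master_port: int
```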
```diff
@@ -77,6 +79,8 @@ class EngineCore:
             self.model_executor.register_failure_callback(
                 executor_fail_callback)
 
+        self.available_gpu_memory_for_kv_cache = -1
+
         # Setup KV Caches and update CacheConfig after profiling.
         num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
             self._initialize_kv_caches(vllm_config)
```
```diff
@@ -137,12 +141,23 @@ class EngineCore:
         # Get all kv cache needed by the model
         kv_cache_specs = self.model_executor.get_kv_cache_specs()
 
-        # Profiles the peak memory usage of the model to determine how much
-        # memory can be allocated for kv cache.
         has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
         if has_kv_cache:
-            available_gpu_memory = \
-                self.model_executor.determine_available_memory()
+            if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
+                dp_group = getattr(self, "dp_group", None)
+                assert dp_group is not None
+                self.available_gpu_memory_for_kv_cache = \
+                    ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
+                available_gpu_memory = [
+                    self.available_gpu_memory_for_kv_cache
+                ] * len(kv_cache_specs)
+            else:
+                # Profiles the peak memory usage of the model to determine how
+                # much memory can be allocated for kv cache.
+                available_gpu_memory = (
+                    self.model_executor.determine_available_memory())
+                self.available_gpu_memory_for_kv_cache = \
+                    available_gpu_memory[0]
         else:
             # Attention free models don't need memory for kv cache
             available_gpu_memory = [0] * len(kv_cache_specs)
```
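The scale-up path hinges on `ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)`: cores launched with `VLLM_ELASTIC_EP_SCALE_UP_LAUNCH=1` pass the sentinel `-1` and receive the KV-cache budget that the existing cores measured (they pass their real value in the last hunk below), so newcomers never run memory profiling. The implementation is not shown in this excerpt; one way to get these semantics on a stateless CPU (gloo) DP group is a rank-agnostic MIN all-reduce with the sentinel mapped to `int64` max — a sketch of the assumed behavior, not the actual vLLM code:

```python
import torch
import torch.distributed as dist


def sync_kv_cache_memory_size(dp_group: dist.ProcessGroup,
                              kv_cache_memory: int) -> int:
    """Sketch of the assumed semantics of
    ParallelConfig.sync_kv_cache_memory_size().

    Existing cores contribute their measured KV-cache budget (bytes); newly
    launched cores contribute -1. After the all-reduce every rank holds the
    smallest measured budget, a conservative choice if ranks measured
    slightly different amounts of free memory.
    """
    if kv_cache_memory == -1:
        # Sentinel: this rank has nothing to contribute; make it lose the MIN.
        kv_cache_memory = torch.iinfo(torch.int64).max
    tensor = torch.tensor([kv_cache_memory], dtype=torch.int64, device="cpu")
    # A broadcast would need a fixed source rank, which a stateless group
    # re-created on every reconfigure cannot rely on; MIN all-reduce works
    # regardless of which ranks are senders and which are receivers.
    dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=dp_group)
    return int(tensor.item())
```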
```diff
@@ -989,6 +1004,50 @@ class DPEngineCoreProc(EngineCoreProc):
         return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                 local_unfinished)
 
+    def reinitialize_distributed(
+            self, reconfig_request: ReconfigureDistributedRequest) -> None:
+        stateless_destroy_torch_distributed_process_group(self.dp_group)
+        self.shutdown()
+
+        parallel_config = self.vllm_config.parallel_config
+        old_dp_size = parallel_config.data_parallel_size
+        parallel_config.data_parallel_size = \
+            reconfig_request.new_data_parallel_size
+        if reconfig_request.new_data_parallel_rank != -1:
+            parallel_config.data_parallel_rank = \
+                reconfig_request.new_data_parallel_rank
+        # local rank specifies device visibility; it should not be changed
+        assert reconfig_request.new_data_parallel_rank_local == \
+            ReconfigureRankType.KEEP_CURRENT_RANK
+        parallel_config.data_parallel_master_ip = \
+            reconfig_request.new_data_parallel_master_ip
+        parallel_config.data_parallel_master_port = \
+            reconfig_request.new_data_parallel_master_port
+        if reconfig_request.new_data_parallel_rank != -2:
+            self.dp_rank = parallel_config.data_parallel_rank
+            self.dp_group = parallel_config.stateless_init_dp_group()
+        reconfig_request.new_data_parallel_master_port = \
+            parallel_config.data_parallel_master_port
+
+        self.model_executor.reinitialize_distributed(reconfig_request)
+        if reconfig_request.new_data_parallel_size > old_dp_size:
+            assert self.available_gpu_memory_for_kv_cache > 0
+            # Pass available_gpu_memory_for_kv_cache from existing
+            # engine-cores to new engine-cores so they can directly
+            # use it in _initialize_kv_caches() rather than profiling.
+            ParallelConfig.sync_kv_cache_memory_size(
+                self.dp_group, self.available_gpu_memory_for_kv_cache)
+            # NOTE(yongji): newly joined workers require a dummy_run even
+            # when CUDA graphs are not used
+            self.model_executor.collective_rpc("compile_or_warm_up_model")
+        if reconfig_request.new_data_parallel_rank == \
+                ReconfigureRankType.SHUTDOWN_CURRENT_RANK:
+            self.shutdown()
+            logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
+        else:
+            logger.info("Distributed environment reinitialized for DP rank %s",
+                        self.dp_rank)
+
 
 class DPEngineCoreActor(DPEngineCoreProc):
     """
```