diff --git a/.buildkite/test_areas/expert_parallelism.yaml b/.buildkite/test_areas/expert_parallelism.yaml index 9a10476ed..1443d847e 100644 --- a/.buildkite/test_areas/expert_parallelism.yaml +++ b/.buildkite/test_areas/expert_parallelism.yaml @@ -20,4 +20,19 @@ steps: - tests/distributed/test_eplb_execute.py commands: - pytest -v -s distributed/test_eplb_execute.py - - pytest -v -s distributed/test_eplb_spec_decode.py \ No newline at end of file + - pytest -v -s distributed/test_eplb_spec_decode.py + +- label: Elastic EP Scaling Test + timeout_in_minutes: 20 + device: b200 + optional: true + working_dir: "/vllm-workspace/tests" + num_devices: 4 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/compilation/ + - tests/distributed/ + commands: + - pytest -v -s distributed/test_elastic_ep.py diff --git a/tests/compile/passes/distributed/test_async_tp.py b/tests/compile/passes/distributed/test_async_tp.py index df7747d1a..abc71768c 100644 --- a/tests/compile/passes/distributed/test_async_tp.py +++ b/tests/compile/passes/distributed/test_async_tp.py @@ -316,7 +316,6 @@ def async_tp_pass_on_test_model( # initialize distributed init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() @@ -334,11 +333,10 @@ def async_tp_pass_on_test_model( model=model_name, trust_remote_code=True, dtype=dtype, seed=42 ) - async_tp_pass = AsyncTPPass(vllm_config) - - # Set the global vllm_config for TestBackend which calls - # get_current_vllm_config() with set_current_vllm_config(vllm_config): + initialize_model_parallel(tensor_model_parallel_size=world_size) + + async_tp_pass = AsyncTPPass(vllm_config) backend = TestBackend(async_tp_pass) assert ( diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py index 6d5113b1e..4beac8c4f 100644 --- a/tests/compile/passes/distributed/test_fusion_all_reduce.py +++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -278,7 +278,6 @@ def all_reduce_fusion_pass_on_test_model( ) init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) custom_ops = [] if enable_rms_norm_custom_op: @@ -304,6 +303,7 @@ def all_reduce_fusion_pass_on_test_model( model=model_name, trust_remote_code=True, dtype=dtype, seed=42 ) with set_current_vllm_config(vllm_config): + initialize_model_parallel(tensor_model_parallel_size=world_size) all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config) diff --git a/tests/compile/passes/distributed/test_sequence_parallelism.py b/tests/compile/passes/distributed/test_sequence_parallelism.py index 46363a9a4..78c3cf92a 100644 --- a/tests/compile/passes/distributed/test_sequence_parallelism.py +++ b/tests/compile/passes/distributed/test_sequence_parallelism.py @@ -242,7 +242,6 @@ def sequence_parallelism_pass_on_test_model( # initialize distributed init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) # configure vllm config for SequenceParallelismPass custom_ops_list = custom_ops.split(",") if custom_ops else [] @@ -272,6 +271,7 @@ def sequence_parallelism_pass_on_test_model( ) with set_current_vllm_config(vllm_config): + initialize_model_parallel(tensor_model_parallel_size=world_size) noop_pass = NoOpEliminationPass(vllm_config) 
sequence_parallelism_pass = SequenceParallelismPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) diff --git a/tests/conftest.py b/tests/conftest.py index 22bb19f2f..5a2beea89 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -176,16 +176,20 @@ def init_test_http_connection(): @pytest.fixture def dist_init(): + from tests.utils import ensure_current_vllm_config + temp_file = tempfile.mkstemp()[1] - init_distributed_environment( - world_size=1, - rank=0, - distributed_init_method=f"file://{temp_file}", - local_rank=0, - backend="nccl", - ) - initialize_model_parallel(1, 1) - yield + + with ensure_current_vllm_config(): + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) + initialize_model_parallel(1, 1) + yield cleanup_dist_env_and_memory() diff --git a/tests/distributed/eplb_utils.py b/tests/distributed/eplb_utils.py index 27a63e021..7c27347fd 100644 --- a/tests/distributed/eplb_utils.py +++ b/tests/distributed/eplb_utils.py @@ -7,6 +7,7 @@ import random import torch import torch.multiprocessing as mp +from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.parallel_state import ( init_distributed_environment, ) @@ -42,7 +43,11 @@ def set_env_vars_and_device(env: dict[str, str]) -> None: local_rank = os.environ["LOCAL_RANK"] device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) - init_distributed_environment() + + # Create a minimal vllm config for init_distributed_environment + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + init_distributed_environment() # Ensure each worker process has the same random seed random.seed(42) diff --git a/tests/distributed/test_elastic_ep.py b/tests/distributed/test_elastic_ep.py new file mode 100644 index 000000000..1d0f615d6 --- /dev/null +++ b/tests/distributed/test_elastic_ep.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import subprocess +import time + +import pytest +import requests + +from ..evals.gsm8k.gsm8k_eval import evaluate_gsm8k +from ..utils import RemoteOpenAIServer, multi_gpu_test + + +@pytest.fixture(autouse=True) +def cleanup_ray_between_tests(): + """Force-stop any lingering Ray processes between tests.""" + subprocess.run(["ray", "stop", "--force"], timeout=30, capture_output=True) + time.sleep(5) + yield + + +MODEL_NAME = "deepseek-ai/DeepSeek-V2-Lite-Chat" + +NUM_GSM8K_QUESTIONS = 256 +EXPECTED_ACCURACY = 0.58 +ACCURACY_TOL = 0.08 +MAX_NUM_SEQS = 32 + + +def _send_scale_command(server: RemoteOpenAIServer, new_dp_size: int) -> bool: + url = server.url_for("scale_elastic_ep") + payload = {"new_data_parallel_size": new_dp_size} + headers = {"Content-Type": "application/json"} + + try: + response = requests.post(url, json=payload, headers=headers, timeout=300) + return response.status_code == 200 + except requests.exceptions.RequestException: + return False + + +def _run_gsm8k_eval(server: RemoteOpenAIServer, stage: str) -> float: + assert server.port is not None + result = evaluate_gsm8k( + num_questions=NUM_GSM8K_QUESTIONS, + host=f"http://{server.host}", + port=server.port, + ) + accuracy = result["accuracy"] + print( + f"[{stage}] GSM8K accuracy: {accuracy:.3f} " + f"({result['num_questions']} questions)" + ) + assert accuracy >= EXPECTED_ACCURACY, ( + f"[{stage}] GSM8K accuracy {accuracy:.3f} is below " + f"expected threshold {EXPECTED_ACCURACY}" + ) + 
return accuracy + + +@multi_gpu_test(num_gpus=4) +def test_elastic_ep_scaling(): + vllm_serve_args = [ + "--trust-remote-code", + "--tensor-parallel-size", + "1", + "--gpu-memory-utilization", + "0.8", + "--max-model-len", + "4096", + "--max-num-seqs", + str(MAX_NUM_SEQS), + "--enable-expert-parallel", + "--all2all-backend", + "allgather_reducescatter", + "--enable-elastic-ep", + "--enable-eplb", + "--eplb-config.num_redundant_experts", + "0", + "--data-parallel-backend", + "ray", + "--data-parallel-size", + "2", + "--api-server-count", + "1", + ] + + leader_address = os.environ.get("LEADER_ADDRESS") + if leader_address: + vllm_serve_args.extend(["--data-parallel-address", leader_address]) + + with RemoteOpenAIServer( + MODEL_NAME, vllm_serve_args, env_dict={}, max_wait_seconds=1200 + ) as server: + initial_accuracy = _run_gsm8k_eval(server, "Initial (2 GPUs)") + + assert _send_scale_command(server, 4) + time.sleep(10) + scale_up_accuracy = _run_gsm8k_eval(server, "After scale up (4 GPUs)") + + assert scale_up_accuracy >= initial_accuracy - ACCURACY_TOL, ( + f"Scale up accuracy {scale_up_accuracy:.3f} dropped more than " + f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}" + ) + + assert _send_scale_command(server, 2) + time.sleep(5) + scale_down_accuracy = _run_gsm8k_eval(server, "After scale down (2 GPUs)") + + assert scale_down_accuracy >= initial_accuracy - ACCURACY_TOL, ( + f"Scale down accuracy {scale_down_accuracy:.3f} dropped more than " + f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}" + ) + + print("\nAccuracy Summary:") + print(f" Initial: {initial_accuracy:.3f}") + print( + f" Scale up: {scale_up_accuracy:.3f} " + f"(diff: {scale_up_accuracy - initial_accuracy:+.3f})" + ) + print( + f" Scale down: {scale_down_accuracy:.3f} " + f"(diff: {scale_down_accuracy - initial_accuracy:+.3f})" + ) + print(f" Tolerance: {ACCURACY_TOL:.3f}") + + +@multi_gpu_test(num_gpus=4) +def test_elastic_ep_scaling_uneven(): + """Test scale up with uneven worker distribution. + + This tests the case where num_new_workers % old_dp_size != 0, + specifically 2 -> 3 where remainder = 1 % 2 = 1. + This exercises the remainder handling in sender-receiver pairing. 
+ """ + vllm_serve_args = [ + "--trust-remote-code", + "--tensor-parallel-size", + "1", + "--gpu-memory-utilization", + "0.8", + "--max-model-len", + "4096", + "--max-num-seqs", + str(MAX_NUM_SEQS), + "--enable-expert-parallel", + "--all2all-backend", + "allgather_reducescatter", + "--enable-elastic-ep", + "--enable-eplb", + "--eplb-config.num_redundant_experts", + "0", + "--data-parallel-backend", + "ray", + "--data-parallel-size", + "2", + "--api-server-count", + "1", + ] + + leader_address = os.environ.get("LEADER_ADDRESS") + if leader_address: + vllm_serve_args.extend(["--data-parallel-address", leader_address]) + + with RemoteOpenAIServer( + MODEL_NAME, vllm_serve_args, env_dict={}, max_wait_seconds=1200 + ) as server: + initial_accuracy = _run_gsm8k_eval(server, "Initial (2 GPUs)") + + # Scale 2 -> 3: This has remainder = 1 % 2 = 1 + # Tests uneven sender-receiver pairing + assert _send_scale_command(server, 3) + time.sleep(10) + scale_up_accuracy = _run_gsm8k_eval(server, "After scale up (3 GPUs)") + + assert scale_up_accuracy >= initial_accuracy - ACCURACY_TOL, ( + f"Scale up accuracy {scale_up_accuracy:.3f} dropped more than " + f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}" + ) + + # Scale back down to 2 + assert _send_scale_command(server, 2) + time.sleep(5) + scale_down_accuracy = _run_gsm8k_eval(server, "After scale down (2 GPUs)") + + assert scale_down_accuracy >= initial_accuracy - ACCURACY_TOL, ( + f"Scale down accuracy {scale_down_accuracy:.3f} dropped more than " + f"{ACCURACY_TOL} below initial accuracy {initial_accuracy:.3f}" + ) + + print("\nAccuracy Summary (Uneven Scaling):") + print(f" Initial: {initial_accuracy:.3f}") + print( + f" Scale up: {scale_up_accuracy:.3f} " + f"(diff: {scale_up_accuracy - initial_accuracy:+.3f})" + ) + print( + f" Scale down: {scale_down_accuracy:.3f} " + f"(diff: {scale_down_accuracy - initial_accuracy:+.3f})" + ) + print(f" Tolerance: {ACCURACY_TOL:.3f}") diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py index 48afc39c6..674a665b0 100644 --- a/tests/distributed/test_eplb_execute.py +++ b/tests/distributed/test_eplb_execute.py @@ -8,6 +8,7 @@ import pytest import torch import torch.distributed +from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed.eplb.rebalance_execute import ( move_from_buffer, rearrange_expert_weights_inplace, @@ -244,90 +245,95 @@ def _test_async_transfer_layer_without_mtp_worker( num_logical_experts: int, ) -> None: set_env_vars_and_device(env) - ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 - ) - tp_group = get_tp_group() - ep_group = tp_group.device_group - ep_rank = torch.distributed.get_rank() - device = torch.device(f"cuda:{ep_rank}") + vllm_config = VllmConfig() + vllm_config.parallel_config.tensor_parallel_size = world_size - total_physical_experts = world_size * num_local_experts - hidden_sizes = [16, 32] + with set_current_vllm_config(vllm_config): + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) - redundancy_config = create_redundancy_config( - num_logical_experts, - total_physical_experts, - ) - old_indices = create_expert_indices_with_redundancy( - num_layers, - num_logical_experts, - total_physical_experts, - redundancy_config, - ) + tp_group = get_tp_group() + ep_group = tp_group.device_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") - 
new_redundancy_config = create_redundancy_config( - num_logical_experts, - total_physical_experts, - ) - new_indices = create_expert_indices_with_redundancy( - num_layers, - num_logical_experts, - total_physical_experts, - new_redundancy_config, - ) + total_physical_experts = world_size * num_local_experts + hidden_sizes = [16, 32] - expert_weights = create_expert_weights( - num_layers, - num_local_experts, - hidden_sizes, - ep_rank, - device, - old_indices, - ) - old_indices_cpu = old_indices.cpu() - new_indices_cpu = new_indices.cpu() + redundancy_config = create_redundancy_config( + num_logical_experts, + total_physical_experts, + ) + old_indices = create_expert_indices_with_redundancy( + num_layers, + num_logical_experts, + total_physical_experts, + redundancy_config, + ) - expert_buffer = [torch.empty_like(w) for w in expert_weights[0]] - cuda_stream = torch.cuda.Stream(device=device) + new_redundancy_config = create_redundancy_config( + num_logical_experts, + total_physical_experts, + ) + new_indices = create_expert_indices_with_redundancy( + num_layers, + num_logical_experts, + total_physical_experts, + new_redundancy_config, + ) - for layer_idx in range(num_layers): - is_unchanged, is_received_locally, recv_metadata = asyncio.run( - transfer_layer( - old_layer_indices=old_indices_cpu[layer_idx], - new_layer_indices=new_indices_cpu[layer_idx], - expert_weights=expert_weights[layer_idx], - expert_weights_buffer=expert_buffer, - ep_group=ep_group, - cuda_stream=cuda_stream, + expert_weights = create_expert_weights( + num_layers, + num_local_experts, + hidden_sizes, + ep_rank, + device, + old_indices, + ) + old_indices_cpu = old_indices.cpu() + new_indices_cpu = new_indices.cpu() + + expert_buffer = [torch.empty_like(w) for w in expert_weights[0]] + cuda_stream = torch.cuda.Stream(device=device) + + for layer_idx in range(num_layers): + is_unchanged, is_received_locally, recv_metadata = asyncio.run( + transfer_layer( + old_layer_indices=old_indices_cpu[layer_idx], + new_layer_indices=new_indices_cpu[layer_idx], + expert_weights=expert_weights[layer_idx], + expert_weights_buffer=expert_buffer, + ep_group=ep_group, + cuda_stream=cuda_stream, + ) + ) + cuda_stream.synchronize() + move_from_buffer( + expert_weights=expert_weights[layer_idx], + expert_weights_buffers=expert_buffer, + is_unchanged=is_unchanged, + is_received_locally=is_received_locally, + recv_metadata=recv_metadata, + new_indices=new_indices_cpu[layer_idx].numpy(), + ep_rank=ep_rank, ) - ) - cuda_stream.synchronize() - move_from_buffer( - expert_weights=expert_weights[layer_idx], - expert_weights_buffers=expert_buffer, - is_unchanged=is_unchanged, - is_received_locally=is_received_locally, - recv_metadata=recv_metadata, - new_indices=new_indices_cpu[layer_idx].numpy(), - ep_rank=ep_rank, - ) - verify_expert_weights_after_shuffle( - expert_weights, - new_indices, - hidden_sizes, - ep_rank, - num_local_experts, - ) - verify_redundant_experts_have_same_weights( - expert_weights, - new_indices, - hidden_sizes, - world_size, - num_local_experts, - ) + verify_expert_weights_after_shuffle( + expert_weights, + new_indices, + hidden_sizes, + ep_rank, + num_local_experts, + ) + verify_redundant_experts_have_same_weights( + expert_weights, + new_indices, + hidden_sizes, + world_size, + num_local_experts, + ) def _test_rearrange_expert_weights_with_redundancy( @@ -336,71 +342,76 @@ def _test_rearrange_expert_weights_with_redundancy( # Initialize model parallel (using tensor parallel as an entrypoint # to expert parallel) 
set_env_vars_and_device(env) - ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 - ) - ep_group = get_tp_group().cpu_group - ep_rank = torch.distributed.get_rank() - device = torch.device(f"cuda:{ep_rank}") + vllm_config = VllmConfig() + vllm_config.parallel_config.tensor_parallel_size = world_size - # Test parameters - total_physical_experts = world_size * num_local_experts - hidden_sizes = [32, 64] # Two different weight matrices + with set_current_vllm_config(vllm_config): + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) - # Create old expert indices (with redundancy) - redundancy_config = create_redundancy_config( - num_logical_experts, total_physical_experts - ) + ep_group = get_tp_group().cpu_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") - old_indices = create_expert_indices_with_redundancy( - num_layers, - num_logical_experts, - total_physical_experts, - redundancy_config, - ) + # Test parameters + total_physical_experts = world_size * num_local_experts + hidden_sizes = [32, 64] # Two different weight matrices - # Create new expert indices (with redundancy) - new_redundancy_config = create_redundancy_config( - num_logical_experts, total_physical_experts - ) - new_indices = create_expert_indices_with_redundancy( - num_layers, - num_logical_experts, - total_physical_experts, - new_redundancy_config, - ) + # Create old expert indices (with redundancy) + redundancy_config = create_redundancy_config( + num_logical_experts, total_physical_experts + ) - # Create expert weights - expert_weights = create_expert_weights( - num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices - ) + old_indices = create_expert_indices_with_redundancy( + num_layers, + num_logical_experts, + total_physical_experts, + redundancy_config, + ) - # Execute weight rearrangement - rearrange_expert_weights_inplace( - old_indices, - new_indices, - expert_weights, - ep_group, - is_profile=False, - ) + # Create new expert indices (with redundancy) + new_redundancy_config = create_redundancy_config( + num_logical_experts, total_physical_experts + ) + new_indices = create_expert_indices_with_redundancy( + num_layers, + num_logical_experts, + total_physical_experts, + new_redundancy_config, + ) - # Verify the rearrangement result - verify_expert_weights_after_shuffle( - expert_weights, - new_indices, - hidden_sizes, - ep_rank, - num_local_experts, - ) + # Create expert weights + expert_weights = create_expert_weights( + num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices + ) - verify_redundant_experts_have_same_weights( - expert_weights, - new_indices, - hidden_sizes, - world_size, - num_local_experts, - ) + # Execute weight rearrangement + rearrange_expert_weights_inplace( + old_indices, + new_indices, + expert_weights, + ep_group, + is_profile=False, + ) + + # Verify the rearrangement result + verify_expert_weights_after_shuffle( + expert_weights, + new_indices, + hidden_sizes, + ep_rank, + num_local_experts, + ) + + verify_redundant_experts_have_same_weights( + expert_weights, + new_indices, + hidden_sizes, + world_size, + num_local_experts, + ) @pytest.mark.parametrize( @@ -444,58 +455,63 @@ def test_rearrange_expert_weights_with_redundancy( def _test_rearrange_expert_weights_no_change(env, world_size) -> None: set_env_vars_and_device(env) - ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, 
pipeline_model_parallel_size=1 - ) - ep_group = get_tp_group().cpu_group - ep_rank = torch.distributed.get_rank() - device = torch.device(f"cuda:{ep_rank}") + vllm_config = VllmConfig() + vllm_config.parallel_config.tensor_parallel_size = world_size - num_layers = 2 - num_local_experts = 2 - total_physical_experts = world_size * num_local_experts - num_logical_experts = total_physical_experts // 2 # Some redundancy - hidden_sizes = [32, 64] + with set_current_vllm_config(vllm_config): + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) - # Create redundancy configuration - redundancy_config = [2] * num_logical_experts + ep_group = get_tp_group().cpu_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") - # Same indices - no change - indices = create_expert_indices_with_redundancy( - num_layers, num_logical_experts, total_physical_experts, redundancy_config - ) + num_layers = 2 + num_local_experts = 2 + total_physical_experts = world_size * num_local_experts + num_logical_experts = total_physical_experts // 2 # Some redundancy + hidden_sizes = [32, 64] - expert_weights = create_expert_weights( - num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices - ) + # Create redundancy configuration + redundancy_config = [2] * num_logical_experts - # Save original weights - original_weights = [] - for layer_weights in expert_weights: - layer_copy = [] - for weight in layer_weights: - layer_copy.append(weight.clone()) - original_weights.append(layer_copy) + # Same indices - no change + indices = create_expert_indices_with_redundancy( + num_layers, num_logical_experts, total_physical_experts, redundancy_config + ) - # Execute rearrangement (should be no change) - rearrange_expert_weights_inplace( - indices, - indices, # Same indices - expert_weights, - ep_group, - is_profile=False, - ) + expert_weights = create_expert_weights( + num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices + ) - # Verify that the weights have not changed - for layer in range(num_layers): - for weight_idx in range(len(hidden_sizes)): - torch.testing.assert_close( - expert_weights[layer][weight_idx], - original_weights[layer][weight_idx], - msg=f"""Layer {layer}, weight {weight_idx} + # Save original weights + original_weights = [] + for layer_weights in expert_weights: + layer_copy = [] + for weight in layer_weights: + layer_copy.append(weight.clone()) + original_weights.append(layer_copy) + + # Execute rearrangement (should be no change) + rearrange_expert_weights_inplace( + indices, + indices, # Same indices + expert_weights, + ep_group, + is_profile=False, + ) + + # Verify that the weights have not changed + for layer in range(num_layers): + for weight_idx in range(len(hidden_sizes)): + torch.testing.assert_close( + expert_weights[layer][weight_idx], + original_weights[layer][weight_idx], + msg=f"""Layer {layer}, weight {weight_idx} should remain unchanged""", - ) + ) @pytest.mark.parametrize( @@ -538,64 +554,69 @@ def test_rearrange_expert_weights_no_change(world_size): def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None: set_env_vars_and_device(env) - ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 - ) - ep_group = get_tp_group().cpu_group - ep_rank = torch.distributed.get_rank() - device = torch.device(f"cuda:{ep_rank}") + vllm_config = VllmConfig() + vllm_config.parallel_config.tensor_parallel_size = world_size - 
num_layers = 1 - num_local_experts = 2 - total_physical_experts = world_size * num_local_experts - num_logical_experts = total_physical_experts // 2 - hidden_sizes = [32] + with set_current_vllm_config(vllm_config): + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) - # Create different index distributions - old_redundancy = create_redundancy_config( - num_logical_experts, total_physical_experts - ) - new_redundancy = create_redundancy_config( - num_logical_experts, total_physical_experts - ) + ep_group = get_tp_group().cpu_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") - old_indices = create_expert_indices_with_redundancy( - num_layers, num_logical_experts, total_physical_experts, old_redundancy - ) - new_indices = create_expert_indices_with_redundancy( - num_layers, num_logical_experts, total_physical_experts, new_redundancy - ) + num_layers = 1 + num_local_experts = 2 + total_physical_experts = world_size * num_local_experts + num_logical_experts = total_physical_experts // 2 + hidden_sizes = [32] - expert_weights = create_expert_weights( - num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices - ) + # Create different index distributions + old_redundancy = create_redundancy_config( + num_logical_experts, total_physical_experts + ) + new_redundancy = create_redundancy_config( + num_logical_experts, total_physical_experts + ) - # Save original weights - original_weights = [] - for layer_weights in expert_weights: - layer_copy = [] - for weight in layer_weights: - layer_copy.append(weight.clone()) - original_weights.append(layer_copy) + old_indices = create_expert_indices_with_redundancy( + num_layers, num_logical_experts, total_physical_experts, old_redundancy + ) + new_indices = create_expert_indices_with_redundancy( + num_layers, num_logical_experts, total_physical_experts, new_redundancy + ) - # Execute profile mode rearrangement - rearrange_expert_weights_inplace( - old_indices, - new_indices, - expert_weights, - ep_group, - is_profile=True, # Profile mode - ) + expert_weights = create_expert_weights( + num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices + ) - # In profile mode, the weights should remain unchanged - for layer in range(num_layers): - for weight_idx in range(len(hidden_sizes)): - torch.testing.assert_close( - expert_weights[layer][weight_idx], - original_weights[layer][weight_idx], - msg="In profile mode, the weights should remain unchanged", - ) + # Save original weights + original_weights = [] + for layer_weights in expert_weights: + layer_copy = [] + for weight in layer_weights: + layer_copy.append(weight.clone()) + original_weights.append(layer_copy) + + # Execute profile mode rearrangement + rearrange_expert_weights_inplace( + old_indices, + new_indices, + expert_weights, + ep_group, + is_profile=True, # Profile mode + ) + + # In profile mode, the weights should remain unchanged + for layer in range(num_layers): + for weight_idx in range(len(hidden_sizes)): + torch.testing.assert_close( + expert_weights[layer][weight_idx], + original_weights[layer][weight_idx], + msg="In profile mode, the weights should remain unchanged", + ) @pytest.mark.parametrize("world_size", [2, 4]) diff --git a/tests/distributed/test_nccl_symm_mem_allreduce.py b/tests/distributed/test_nccl_symm_mem_allreduce.py index eeb74bdf5..b81624fe1 100644 --- a/tests/distributed/test_nccl_symm_mem_allreduce.py +++ 
b/tests/distributed/test_nccl_symm_mem_allreduce.py @@ -10,6 +10,7 @@ import torch.distributed as dist import torch.multiprocessing as mp import vllm.envs as envs +from tests.utils import ensure_current_vllm_config from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops @@ -51,7 +52,8 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int): ) init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) + with ensure_current_vllm_config(): + initialize_model_parallel(tensor_model_parallel_size=world_size) cuda_communicator = typing.cast( CudaCommunicator, get_tp_group().device_communicator diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index c7c9d0602..d20710335 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -9,6 +9,7 @@ import pytest import torch import torch.distributed +from tests.utils import ensure_current_vllm_config from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary @@ -112,7 +113,8 @@ def test_pynccl_multiple_allreduce(): @worker_fn_wrapper def multiple_allreduce_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") - ensure_model_parallel_initialized(2, 2) + with ensure_current_vllm_config(): + ensure_model_parallel_initialized(2, 2) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) with graph_capture(device=device): # two tp groups can communicate independently diff --git a/tests/kernels/mamba/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py index 98879ff6e..322e717e9 100644 --- a/tests/kernels/mamba/test_mamba_mixer2.py +++ b/tests/kernels/mamba/test_mamba_mixer2.py @@ -6,7 +6,7 @@ import unittest import pytest import torch -from tests.utils import multi_gpu_test +from tests.utils import ensure_current_vllm_config, multi_gpu_test from vllm.distributed.parallel_state import ( init_distributed_environment, initialize_model_parallel, @@ -87,7 +87,8 @@ def mixer2_gated_norm_tensor_parallel( # initialize distributed init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) + with ensure_current_vllm_config(): + initialize_model_parallel(tensor_model_parallel_size=world_size) # create random weights an inputs weight = torch.rand((hidden_size,), dtype=dtype, device=device) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index d0d8382ac..71180a2c7 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -45,21 +45,24 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): @pytest.fixture def dist_init(): + from tests.utils import ensure_current_vllm_config + temp_file = tempfile.mkstemp()[1] backend = "nccl" if current_platform.is_cpu() or current_platform.is_tpu(): backend = "gloo" - init_distributed_environment( - world_size=1, - rank=0, - distributed_init_method=f"file://{temp_file}", - local_rank=0, - backend=backend, - ) - initialize_model_parallel(1, 1) - yield + with ensure_current_vllm_config(): + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend=backend, + ) + 
initialize_model_parallel(1, 1) + yield cleanup_dist_env_and_memory(shutdown_ray=True) diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index 382999bca..b2db7968e 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -6,7 +6,7 @@ import random import pytest import torch -from tests.utils import multi_gpu_test +from tests.utils import ensure_current_vllm_config, multi_gpu_test from vllm import _custom_ops as ops from vllm.distributed import ( init_distributed_environment, @@ -631,7 +631,8 @@ def use_fused_moe_lora_kernel_tensor_parallel( local_rank=local_rank, distributed_init_method=init_method, ) - initialize_model_parallel(world_size, 1) + with ensure_current_vllm_config(): + initialize_model_parallel(world_size, 1) tp_size = get_tensor_model_parallel_world_size() input_dim = K if column_parallel else N diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 445aaf9cb..274142e8d 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -13,6 +13,7 @@ from vllm.config import ( ParallelConfig, SchedulerConfig, VllmConfig, + set_current_vllm_config, ) from vllm.config.load import LoadConfig from vllm.config.lora import LoRAConfig @@ -77,8 +78,9 @@ def test_worker_apply_lora(qwen3_lora_files): distributed_init_method=f"file://{tempfile.mkstemp()[1]}", ) - worker.init_device() - worker.load_model() + with set_current_vllm_config(vllm_config): + worker.init_device() + worker.load_model() set_active_loras(worker, []) assert worker.list_loras() == set() diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index 24e49e9d6..17d82b125 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -6,7 +6,7 @@ import pytest import torch import torch.multiprocessing as mp -from tests.utils import multi_gpu_test +from tests.utils import ensure_current_vllm_config, multi_gpu_test from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import ( init_distributed_environment, @@ -117,7 +117,8 @@ def run_dp_sharded_vision_model_vs_direct( # initialize distributed init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) + with ensure_current_vllm_config(): + initialize_model_parallel(tensor_model_parallel_size=world_size) # Create a test input tensor image_input = torch.randn(batch_size, 3, 224, 224) @@ -302,7 +303,8 @@ def run_dp_sharded_mrope_vision_model_vs_direct( # initialize distributed init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) + with ensure_current_vllm_config(): + initialize_model_parallel(tensor_model_parallel_size=world_size) # Create test data grid_thw_list = [] @@ -377,7 +379,8 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker( ) init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) + with ensure_current_vllm_config(): + initialize_model_parallel(tensor_model_parallel_size=world_size) # Create empty inputs pixel_values = torch.empty((0, 768)) @@ -425,7 +428,8 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker( ) init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=world_size) + with ensure_current_vllm_config(): + initialize_model_parallel(tensor_model_parallel_size=world_size) # Create images with very different sizes grid_thw_list = [ diff --git a/tests/utils.py b/tests/utils.py index 
4041c2617..d407733a3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -895,6 +895,36 @@ def compare_all_settings( ) +@contextmanager +def ensure_current_vllm_config(): + """Ensures a vllm config is set for the duration of the context. + + If a config is already set, this is a no-op. Otherwise, it creates a default + VllmConfig and sets it for the duration of the context. + + Used for tests that call functions which require a vllm config but don't + need a specific config. + + Example: + with ensure_current_vllm_config(): + init_distributed_environment(...) + ensure_model_parallel_initialized(...) + """ + from vllm.config import ( + VllmConfig, + get_current_vllm_config_or_none, + set_current_vllm_config, + ) + + if get_current_vllm_config_or_none() is not None: + # Config already set, just yield + yield + else: + # No config set, create a default one for the duration + with set_current_vllm_config(VllmConfig()): + yield + + def init_test_distributed_environment( tp_size: int, pp_size: int, @@ -921,6 +951,7 @@ def init_test_distributed_environment( distributed_init_method=distributed_init_method, local_rank=local_rank, ) + ensure_model_parallel_initialized(tp_size, pp_size) else: # No config set, create a default one for the test with set_current_vllm_config(VllmConfig()): @@ -930,7 +961,7 @@ def init_test_distributed_environment( distributed_init_method=distributed_init_method, local_rank=local_rank, ) - ensure_model_parallel_initialized(tp_size, pp_size) + ensure_model_parallel_initialized(tp_size, pp_size) def multi_process_parallel( diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 93e6822e6..d1c43b645 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -789,8 +789,11 @@ def test_hybrid_attention_mamba_tensor_shapes(): "MASTER_PORT": "12345", } ) - init_distributed_environment() - initialize_model_parallel(tensor_model_parallel_size=1) + from tests.utils import ensure_current_vllm_config + + with ensure_current_vllm_config(): + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=1) torch.set_default_dtype(torch.float16) model_config = ModelConfig( diff --git a/tests/v1/worker/test_worker_memory_snapshot.py b/tests/v1/worker/test_worker_memory_snapshot.py index 66330127b..27a9b4a75 100644 --- a/tests/v1/worker/test_worker_memory_snapshot.py +++ b/tests/v1/worker/test_worker_memory_snapshot.py @@ -10,6 +10,7 @@ from unittest.mock import patch import pytest import torch +from vllm.config import set_current_vllm_config from vllm.engine.arg_utils import EngineArgs from vllm.utils.mem_utils import MemorySnapshot from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment @@ -95,7 +96,12 @@ def worker_process( side_effect=make_operation_tracker("nccl_all_reduce", original_all_reduce), ) - with init_patch, memory_patch, all_reduce_patch: + with ( + init_patch, + memory_patch, + all_reduce_patch, + set_current_vllm_config(vllm_config), + ): # Initialize device (this is where we test the order) worker.init_device() diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 850ddae9a..5dff296d0 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -319,3 +319,52 @@ class TorchCompileWithNoGuardsWrapper: yield finally: self.__class__.forward.__code__ = original + + +def reset_compile_wrapper(model: torch.nn.Module) -> None: + """ + Clean up compiled model and captured CUDA graphs for elastic EP. 
+ """ + if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr( + model, "model" + ): + model = model.model + if not isinstance(model, TorchCompileWithNoGuardsWrapper): + return + # model.do_not_compile is set by the @support_torch_compile decorator + if hasattr(model, "do_not_compile") and model.do_not_compile: + return + from vllm.compilation.counter import compilation_counter + + # reset the compilation counter + compilation_counter.num_models_seen = 0 + compilation_counter.num_graphs_seen = 0 + compilation_counter.num_piecewise_graphs_seen = 0 + compilation_counter.num_piecewise_capturable_graphs_seen = 0 + compilation_counter.num_backend_compilations = 0 + compilation_counter.num_gpu_runner_capture_triggers = 0 + compilation_counter.num_cudagraph_captured = 0 + compilation_counter.num_inductor_compiles = 0 + compilation_counter.num_eager_compiles = 0 + compilation_counter.num_cache_entries_updated = 0 + compilation_counter.num_compiled_artifacts_saved = 0 + compilation_counter.stock_torch_compile_count = 0 + + # Clear the AOT compiled function so the model is forced to + # recompile on the next call. Without this, decorators.py + # __call__ uses the stale aot_compiled_fn whose torchinductor + # kernels have old parameters (expert_map size for example) + # baked in as compile-time constants. + if hasattr(model, "aot_compiled_fn"): + model.aot_compiled_fn = None + if hasattr(model, "was_aot_compile_fn_loaded_from_disk"): + model.was_aot_compile_fn_loaded_from_disk = False + + # Reset the cache_dir so VllmBackend recomputes the hash + # (data_parallel_size changed, so the config hash differs). + compilation_config = model.vllm_config.compilation_config + compilation_config.cache_dir = "" + compilation_config.local_cache_dir = "" + + model.__class__.forward.__code__ = model.original_code_object() + TorchCompileWithNoGuardsWrapper.__init__(model) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index fa4f72dcc..59df4a214 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -165,6 +165,9 @@ class ParallelConfig: disable_custom_all_reduce: bool = False """Disable the custom all-reduce kernel and fall back to NCCL.""" + enable_elastic_ep: bool = False + """Enable elastic expert parallelism with stateless NCCL groups for DP/EP.""" + enable_dbo: bool = False """Enable dual batch overlap for the model executor.""" ubatch_size: int = 0 @@ -244,6 +247,34 @@ class ParallelConfig: Set to be private as it's not intended to be configured by users. """ + _stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list) + """List of open ports for stateless DP groups when enable_elastic_ep is True. + Set to be private as it's not intended to be configured by users. + It is a list of list[int], with each inner list contains a set of 3 ports + to be used for setting up the stateless CPU/device/TCPStore groups + in StatelessGroupCoordinator. The number of inner lists is equal to + the number of DP groups, + i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size, + and len(self._stateless_dp_group_port_list[i]) == 3 for all i. + """ + + _stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list) + """List of open ports for stateless EP groups when enable_elastic_ep is True. + Set to be private as it's not intended to be configured by users. 
+ len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size, + """ + + _stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list) + """List of open ports for stateless EPLB groups when enable_elastic_ep is True. + Same topology as EP but separate NCCL communicator to avoid deadlocks. + """ + + _stateless_world_group_port_list: list[list[int]] = Field(default_factory=list) + """List of open ports for stateless world group when enable_elastic_ep is True. + Set to be private as it's not intended to be configured by users. + len(self._stateless_world_group_port_list) == 1, + """ + decode_context_parallel_size: int = 1 """Number of decode context parallel groups, because the world size does not change by dcp, it simply reuse the GPUs of TP group, and tp_size @@ -402,7 +433,67 @@ class ParallelConfig: return answer - def stateless_init_dp_group(self) -> ProcessGroup: + def allocate_elastic_ep_ports(self) -> None: + """Allocate all ports for elastic EP (stateless groups + DP master). + + Must be called AFTER ray.init() so that ports claimed by Ray's + idle worker pool are already in use and won't be returned by + get_open_ports_list(). + """ + if not self.enable_elastic_ep: + return + if self._stateless_world_group_port_list: + return + + num_world_groups = 1 + dp_size = self.data_parallel_size + ep_size = self.data_parallel_size * self.world_size_across_dp + num_dp_groups = max(1, self.world_size_across_dp // dp_size) + num_ep_groups = max(1, self.world_size_across_dp // ep_size) + num_eplb_groups = num_ep_groups + total_stateless_ports = ( + num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups + ) * 3 + num_dp_master_ports = 5 + + all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports) + + self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:] + self.data_parallel_master_port = self._data_parallel_master_port_list.pop() + all_ports = all_ports[:-num_dp_master_ports] + + self._stateless_world_group_port_list = [ + all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3) + ] + start_idx = num_world_groups * 3 + self._stateless_dp_group_port_list = [ + all_ports[i : i + 3] + for i in range(start_idx, start_idx + num_dp_groups * 3, 3) + ] + start_idx += num_dp_groups * 3 + self._stateless_ep_group_port_list = [ + all_ports[i : i + 3] + for i in range(start_idx, start_idx + num_ep_groups * 3, 3) + ] + start_idx += num_ep_groups * 3 + self._stateless_eplb_group_port_list = [ + all_ports[i : i + 3] + for i in range(start_idx, start_idx + num_eplb_groups * 3, 3) + ] + + def get_next_stateless_world_group_port(self) -> list[int]: + return self._stateless_world_group_port_list.pop() + + def get_next_stateless_dp_group_port(self) -> list[int]: + return self._stateless_dp_group_port_list.pop() + + def get_next_stateless_ep_group_port(self) -> list[int]: + return self._stateless_ep_group_port_list.pop() + + def get_next_stateless_eplb_group_port(self) -> list[int]: + return self._stateless_eplb_group_port_list.pop() + + def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup: # NOTE: In high-concurrency scenarios multiple processes # can pick the same (currently free) port through a race # condition when calling `get_open_port()`. 
When the first @@ -426,7 +517,8 @@ class ParallelConfig: self.get_next_dp_init_port(), self.data_parallel_rank, self.data_parallel_size, - backend=current_platform.dist_backend, + backend="gloo", + return_store=return_store, ) except DistNetworkError as e: # We only want to retry when the root cause is EADDRINUSE. @@ -561,6 +653,21 @@ class ParallelConfig: logger.info("Using external launcher for distributed inference.") self.world_size *= self.data_parallel_size + if self.enable_elastic_ep: + if not self.enable_eplb: + raise ValueError("Elastic EP is only supported with enable_eplb=True.") + if self.pipeline_parallel_size > 1: + raise ValueError( + "Elastic EP is not supported with pipeline parallelism " + f"(pipeline_parallel_size={self.pipeline_parallel_size})." + ) + if self.data_parallel_external_lb or self.data_parallel_hybrid_lb: + raise NotImplementedError( + "Elastic EP is not compatible with data_parallel_external_lb " + "or data_parallel_hybrid_lb. Elastic EP relies on a single API " + "server and core client to coordinate scale up/down." + ) + if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. if self.distributed_executor_backend == "external_launcher": @@ -573,9 +680,12 @@ class ParallelConfig: "Set data_parallel_rank to %d automatically.", self.data_parallel_rank, ) - if not self._data_parallel_master_port_list: - self._data_parallel_master_port_list = get_open_ports_list(5) - self.data_parallel_master_port = self._data_parallel_master_port_list.pop() + if not self.enable_elastic_ep: + if not self._data_parallel_master_port_list: + self._data_parallel_master_port_list = get_open_ports_list(5) + self.data_parallel_master_port = ( + self._data_parallel_master_port_list.pop() + ) if not (0 <= self.data_parallel_rank < self.data_parallel_size): raise ValueError( @@ -602,7 +712,7 @@ class ParallelConfig: os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" logger.info("Disabling V1 multiprocessing for external launcher.") - if self.distributed_executor_backend is None and self.world_size > 1: + if self.distributed_executor_backend is None and self.world_size_across_dp > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 4acab4e3c..3efcebd54 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -31,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase): debugging. """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def naive_multicast( self, @@ -138,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase): all-gather (dispatch) and reduce-scatter (combine). """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def dispatch_router_logits( self, @@ -239,12 +239,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): All2All communication based on DeepEP High-Throughput kernels. """ - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): assert has_deep_ep(), ( "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md" " to install DeepEP kernels." 
) # noqa - super().__init__(cpu_group) + super().__init__(cpu_group, tcp_store_group) self.handle_cache = Cache() # This is the DeepEP default. Stick to it till we can establish @@ -282,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): raise NotImplementedError def destroy(self): - pass + with self.handle_cache._lock: + for _, handle in self.handle_cache._cache.items(): + handle.destroy() + self.handle_cache._cache.clear() class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): @@ -290,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): All2All communication based on DeepEP High-Throughput kernels. """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def _make_all2all_kwargs(self) -> dict[Any, Any]: # Defaults for internode and intranode are taken from DeepEP tests. @@ -314,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): num_rdma_bytes=num_rdma_bytes, low_latency_mode=False, num_qps_per_rank=num_qps_per_rank, + explicitly_destroy=True, ) def get_handle(self, kwargs): @@ -347,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): All2All communication based on DeepEP Low-Latency kernels. """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def _make_all2all_kwargs( self, @@ -387,6 +391,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): num_qps_per_rank=num_qps_per_rank, allow_nvlink_for_low_latency_mode=True, allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL, + explicitly_destroy=True, ) def get_handle(self, kwargs): @@ -418,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase): rank: int world_size: int - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): assert has_flashinfer_all2all(), ( "flashinfer all2all module not found. 
Please install/check flashinfer" ) # noqa - super().__init__(cpu_group) + super().__init__(cpu_group, tcp_store_group) logger.debug( "Initialize for flashinfer All2All rank=%d, world size=%d", self.rank, diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 572bac80f..2125f7381 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -29,8 +29,9 @@ class All2AllManagerBase: rank: int world_size: int - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): self.cpu_group = cpu_group + self.tcp_store_group = tcp_store_group # compute some common properties from vllm.distributed.parallel_state import ( @@ -47,12 +48,17 @@ class All2AllManagerBase: # when we create this object self.dp_rank = self.dp_group.rank_in_group self.dp_world_size = self.dp_group.world_size - self.rank = dist.get_rank(cpu_group) - self.world_size = dist.get_world_size(cpu_group) + self.rank = cpu_group.rank() + self.world_size = cpu_group.size() # all2all communication often has separate implementations for # intra-node and inter-node communication - self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) + if tcp_store_group is None: + self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) + else: + self.internode = not all( + in_the_same_node_as(tcp_store_group, source_rank=0) + ) def get_handle(self, kwargs): # get a handle for the all2all communication, @@ -121,17 +127,36 @@ class DeviceCommunicatorBase: device: torch.device | None = None, device_group: ProcessGroup | None = None, unique_name: str = "", + global_ranks: list[int] | None = None, + global_world_size: int | None = None, ): self.device = device or torch.device("cpu") self.cpu_group = cpu_group self.device_group = device_group self.unique_name = unique_name - self.rank = dist.get_rank(cpu_group) - self.world_size = dist.get_world_size(cpu_group) - self.ranks = dist.get_process_group_ranks(cpu_group) - self.global_rank = dist.get_rank() - self.global_world_size = dist.get_world_size() - self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) + + # Check if this is a stateless process group + from torch.distributed.distributed_c10d import _world + + is_stateless = _world.pg_map.get(cpu_group, None) is None + + if is_stateless: + # For stateless groups, we can't use torch.distributed methods + self.rank = cpu_group.rank() + self.world_size = cpu_group.size() + assert global_ranks is not None + assert global_world_size is not None + self.ranks = global_ranks + self.global_rank = self.ranks[self.rank] + self.global_world_size = global_world_size + self.rank_in_group = self.rank + else: + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + self.ranks = dist.get_process_group_ranks(cpu_group) + self.global_rank = dist.get_rank() + self.global_world_size = dist.get_world_size() + self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) use_ep = False all2all_backend = None @@ -145,7 +170,7 @@ class DeviceCommunicatorBase: use_ep = config.parallel_config.data_parallel_size > 1 all2all_backend = config.parallel_config.all2all_backend - self.is_ep_communicator = "ep" in unique_name + self.is_ep_communicator = unique_name.split(":")[0] == "ep" self.use_all2all = self.is_ep_communicator and use_ep self.all2all_backend = all2all_backend 
self.all2all_manager: All2AllManagerBase | None = None @@ -275,6 +300,13 @@ class DeviceCommunicatorBase: torch.distributed.recv(tensor, self.ranks[src], self.device_group) return tensor + def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """Broadcast a tensor from source rank to all ranks.""" + if self.world_size == 1: + return tensor + torch.distributed.broadcast(tensor, self.ranks[src], self.device_group) + return tensor + def destroy(self): pass @@ -343,3 +375,6 @@ class DeviceCommunicatorBase: This is a no-op in the base class. """ return hidden_states + + def batch_isend_irecv(self, p2p_ops: list): + raise NotImplementedError diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index dd571482f..5e18dbde9 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import ( from vllm.logger import init_logger from vllm.platforms import current_platform +from ..utils import StatelessProcessGroup from .base_device_communicator import DeviceCommunicatorBase logger = init_logger(__name__) @@ -28,8 +29,18 @@ class CudaCommunicator(DeviceCommunicatorBase): device: torch.device | None = None, device_group: ProcessGroup | None = None, unique_name: str = "", + global_ranks: list[int] | None = None, + global_world_size: int | None = None, + tcp_store_group: StatelessProcessGroup | None = None, ): - super().__init__(cpu_group, device, device_group, unique_name) + super().__init__( + cpu_group, + device, + device_group, + unique_name, + global_ranks, + global_world_size, + ) if "tp" not in unique_name: # custom allreduce or torch symm mem can be used only by tp use_custom_allreduce = False @@ -62,7 +73,7 @@ class CudaCommunicator(DeviceCommunicatorBase): self.pynccl_comm: PyNcclCommunicator | None = None if self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( - group=self.cpu_group, + group=self.cpu_group if tcp_store_group is None else tcp_store_group, device=self.device, ) if is_symmetric_memory_enabled(): @@ -107,19 +118,27 @@ class CudaCommunicator(DeviceCommunicatorBase): if self.all2all_backend == "naive": from .all2all import NaiveAll2AllManager - self.all2all_manager = NaiveAll2AllManager(self.cpu_group) + self.all2all_manager = NaiveAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "allgather_reducescatter": from .all2all import AgRsAll2AllManager - self.all2all_manager = AgRsAll2AllManager(self.cpu_group) + self.all2all_manager = AgRsAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "deepep_high_throughput": from .all2all import DeepEPHTAll2AllManager - self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group) + self.all2all_manager = DeepEPHTAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "deepep_low_latency": from .all2all import DeepEPLLAll2AllManager - self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) + self.all2all_manager = DeepEPLLAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "mori": from .all2all import MoriAll2AllManager @@ -127,7 +146,9 @@ class CudaCommunicator(DeviceCommunicatorBase): elif self.all2all_backend == "flashinfer_all2allv": from .all2all import FlashInferAllToAllManager - self.all2all_manager = FlashInferAllToAllManager(self.cpu_group) + 
self.all2all_manager = FlashInferAllToAllManager( + self.cpu_group, tcp_store_group + ) else: raise ValueError(f"Unknown all2all backend: {self.all2all_backend}") @@ -284,6 +305,18 @@ class CudaCommunicator(DeviceCommunicatorBase): torch.distributed.recv(tensor, self.ranks[src], self.device_group) return tensor + def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """Broadcast a tensor from source rank to all ranks.""" + if self.world_size == 1: + return tensor + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.broadcast(tensor, src) + return tensor + else: + raise ValueError("No PyNCCL communicator found") + def destroy(self): if self.pynccl_comm is not None: self.pynccl_comm = None @@ -403,3 +436,10 @@ class CudaCommunicator(DeviceCommunicatorBase): hidden_states, is_sequence_parallel, ) + + def batch_isend_irecv(self, p2p_ops: list): + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.batch_isend_irecv(p2p_ops) + else: + raise ValueError("No PyNCCL communicator found") diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 2fc35e80f..44dc113e4 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -312,10 +312,19 @@ class PyNcclCommunicator: ) if stream is None: stream = current_stream() + if tensor.dtype in [ + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + ]: + nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8) + else: + nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype) self.nccl.ncclSend( buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), + nccl_dtype, dst, self.comm, cudaStream_t(stream.cuda_stream), @@ -330,10 +339,19 @@ class PyNcclCommunicator: ) if stream is None: stream = current_stream() + if tensor.dtype in [ + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + ]: + nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8) + else: + nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype) self.nccl.ncclRecv( buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), + nccl_dtype, src, self.comm, cudaStream_t(stream.cuda_stream), @@ -384,3 +402,17 @@ class PyNcclCommunicator: def deregister_comm_window(self, window): return self.nccl.ncclCommWindowDeregister(self.comm, window) + + def batch_isend_irecv(self, p2p_ops: list, stream=None): + if self.disabled: + return + if stream is None: + stream = current_stream() + self.group_start() + for op in p2p_ops: + if op.op is torch.distributed.isend: + self.send(op.tensor, op.group_peer, stream) + elif op.op is torch.distributed.irecv: + self.recv(op.tensor, op.group_peer, stream) + + self.group_end() diff --git a/vllm/distributed/elastic_ep/__init__.py b/vllm/distributed/elastic_ep/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm/distributed/elastic_ep/elastic_execute.py b/vllm/distributed/elastic_ep/elastic_execute.py new file mode 100644 index 000000000..22d570660 --- /dev/null +++ b/vllm/distributed/elastic_ep/elastic_execute.py @@ -0,0 +1,529 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +import gc +import weakref +from collections.abc import Iterable, Sequence + +import torch +import torch.nn as nn 
+import torch.nn.functional as F +from torch.distributed import P2POp + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.compilation.wrapper import reset_compile_wrapper +from vllm.config import ( + CompilationMode, + set_current_vllm_config, +) +from vllm.distributed import ( + get_dp_group, + get_ep_group, + get_pcp_group, + get_tp_group, +) +from vllm.distributed.elastic_ep.standby_state import ( + create_standby_groups, + get_standby_dp_group, + get_standby_ep_group, + pop_standby_groups, +) +from vllm.distributed.parallel_state import ( + _replace_active_groups, + prepare_communication_buffer_for_model, +) +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig +from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper +from vllm.v1.worker.workspace import lock_workspace, unlock_workspace + +logger = init_logger(__name__) + + +def batch_transfer_weights( + model: nn.Module, + is_sender: bool, + peer_rank: int, + dp_group: StatelessGroupCoordinator, + expert_weights: Sequence[Iterable[torch.Tensor]], +) -> None: + device_comm = dp_group.device_communicator + if device_comm is None: + raise ValueError("No device communicator found") + + expert_weights_set = set() + for weight_group in expert_weights: + for weight in weight_group: + expert_weights_set.add(weight.data_ptr()) + + state_dict = model.state_dict() + all_params = [] + + for name, param in state_dict.items(): + if name.endswith("expert_map"): + continue + if param.data_ptr() not in expert_weights_set: + all_params.append(param.data) + + assert len(all_params) > 0 + p2p_ops = [] + for param in all_params: + op = object.__new__(P2POp) + if is_sender: + op.op = torch.distributed.isend + op.tensor = param + else: + op.op = torch.distributed.irecv + op.tensor = param + op.group_peer = peer_rank + p2p_ops.append(op) + device_comm.batch_isend_irecv(p2p_ops) + + +def broadcast_expert_mapping( + physical_to_logical: torch.Tensor | None, + num_local_physical_experts: int | None, + num_logical_experts: int | None, + dp_group: StatelessGroupCoordinator, + device: torch.device, + src_rank: int = 0, +) -> tuple[torch.Tensor, int, int]: + if dp_group.rank_in_group == src_rank: + assert physical_to_logical is not None + assert num_local_physical_experts is not None + assert num_logical_experts is not None + assert physical_to_logical.dtype == torch.int64 + shape_tensor = torch.tensor( + list(physical_to_logical.shape), dtype=torch.int64, device="cpu" + ) + metadata_tensor = torch.tensor( + [num_local_physical_experts, num_logical_experts], + dtype=torch.int64, + device="cpu", + ) + else: + shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu") + metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu") + + shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank) + metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank) + + if dp_group.rank_in_group != src_rank: + assert device is not None + physical_to_logical = torch.empty( + tuple(shape_tensor.tolist()), + dtype=torch.int64, + device=device, + ) + + assert physical_to_logical is not None + physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank) + num_local_physical_experts = int(metadata_tensor[0].item()) + num_logical_experts = 
int(metadata_tensor[1].item()) + + return physical_to_logical, num_local_physical_experts, num_logical_experts + + +class ElasticEPScalingExecutor: + def __init__(self, worker): + self.worker_ref = weakref.ref(worker) + self.reconfig_request = None + + @property + def worker(self): + worker = self.worker_ref() + if worker is None: + raise RuntimeError("Worker has been garbage collected") + return worker + + def execute(self, execute_method: str, *args, **kwargs): + method = getattr(self, execute_method, None) + if method is None: + raise ValueError(f"Unknown execute method: {execute_method}") + return method(*args, **kwargs) + + def create_standby_groups( + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: + self.reconfig_request = reconfig_request + new_dp_size = reconfig_request.new_data_parallel_size + world_size = self.worker.vllm_config.parallel_config.world_size + new_world_size_across_dp = world_size * new_dp_size + updated_config = copy.copy(self.worker.vllm_config) + updated_config.parallel_config = copy.deepcopy( + self.worker.vllm_config.parallel_config + ) + updated_config.parallel_config.data_parallel_size = new_dp_size + with set_current_vllm_config(updated_config): + create_standby_groups( + new_dp_size=new_dp_size, + new_world_size_across_dp=new_world_size_across_dp, + master_ip=reconfig_request.new_data_parallel_master_ip, + world_group_ports=reconfig_request.new_stateless_world_group_port_list, + dp_group_ports=reconfig_request.new_stateless_dp_group_port_list, + ep_group_ports=reconfig_request.new_stateless_ep_group_port_list, + eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list, + ) + self.worker.model_runner.eep_eplb_suppressed = True + standby_ep_group = get_standby_ep_group() + assert standby_ep_group is not None + if standby_ep_group.rank == 0: + logger.info("[Elastic EP] EPLB disabled during elastic scaling transition") + + def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None: + standby_dp_group = get_standby_dp_group() + assert standby_dp_group is not None + # Broadcast old_dp_size to all workers in standby group + if standby_dp_group.rank_in_group < old_dp_size: + old_dp_size_tensor = torch.tensor( + [old_dp_size], dtype=torch.int64, device="cpu" + ) + else: + old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu") + old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast( + old_dp_size_tensor, 0 + ) + + num_new_workers = new_dp_size - old_dp_size + dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank + + # Sender-receiver pairing: the first new_workers % old_dp_size + # senders get (k+1) contiguous receivers, the rest get k + # receivers. 
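+ # Hypothetical worked example (not part of the original change), to
+ # illustrate the pairing computed below: with old_dp_size=2 and
+ # new_dp_size=5 there are 3 new workers, so num_dst_per_sender=1 and
+ # remainder=1; sender rank 0 transfers to new ranks 2 and 3, and
+ # sender rank 1 transfers to new rank 4.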
+ num_dst_per_sender = num_new_workers // old_dp_size + remainder = num_new_workers % old_dp_size + + if dp_rank < remainder: + recv_begin = dp_rank * (num_dst_per_sender + 1) + recv_end = recv_begin + num_dst_per_sender + 1 + else: + recv_begin = ( + remainder * (num_dst_per_sender + 1) + + (dp_rank - remainder) * num_dst_per_sender + ) + recv_end = recv_begin + num_dst_per_sender + + ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end)) + + model = self.worker.model_runner.get_model() + for new_worker_rank in sorted(ranks_to_send): + batch_transfer_weights( + model=model, + is_sender=True, + peer_rank=new_worker_rank, + dp_group=standby_dp_group, + expert_weights=model.expert_weights, + ) + torch.cuda.synchronize() + + def broadcast_expert_mapping(self) -> None: + standby_dp_group = get_standby_dp_group() + assert standby_dp_group is not None + model_config = self.worker.model_runner.model_config + eplb_state = self.worker.model_runner.eplb_state + assert eplb_state is not None + eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + physical_to_logical = eplb_model_state.physical_to_logical_map + num_physical_experts = physical_to_logical.shape[1] + num_local_physical_experts = num_physical_experts // get_ep_group().world_size + num_logical_experts = eplb_model_state.logical_replica_count.shape[1] + broadcast_expert_mapping( + physical_to_logical=physical_to_logical, + num_local_physical_experts=num_local_physical_experts, + num_logical_experts=num_logical_experts, + dp_group=standby_dp_group, + src_rank=0, + device=self.worker.device, + ) + + def switch_and_remove(self) -> None: + _replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None) + + def switch_and_prepare(self) -> None: + old_dp_size = get_dp_group().world_size + old_ep_size = get_ep_group().world_size + + _replace_active_groups(**pop_standby_groups()) + + parallel_config = self.worker.vllm_config.parallel_config + reconfig_request = self.reconfig_request + assert reconfig_request is not None + new_dp_size = reconfig_request.new_data_parallel_size + new_ep_size = get_ep_group().world_size + + parallel_config.data_parallel_size = new_dp_size + if ( + reconfig_request.new_data_parallel_rank + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if ( + reconfig_request.new_data_parallel_rank_local + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank_local = ( + reconfig_request.new_data_parallel_rank_local + ) + parallel_config.data_parallel_master_ip = ( + reconfig_request.new_data_parallel_master_ip + ) + parallel_config.data_parallel_master_port = ( + reconfig_request.new_data_parallel_master_port + ) + + # Reconfigure MoE modules with new EP size + moe_modules = [ + module + for module in self.worker.model_runner.model.modules() + if ( + module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE" + ) + ] + num_local_experts = moe_modules[0].moe_config.num_local_experts + assert all( + module.moe_config.num_local_experts == num_local_experts + for module in moe_modules + ), "All MoE modules must have the same number of experts" + for module in moe_modules: + module.moe_config.num_experts = num_local_experts * new_ep_size + module.global_num_experts = module.moe_config.num_experts + tp_size = get_tp_group().world_size + is_sequence_parallel = parallel_config.use_sequence_parallel_moe + sp_size = tp_size if is_sequence_parallel 
else 1 + module.moe_parallel_config = FusedMoEParallelConfig.make( + tp_size_=tp_size, + pcp_size_=get_pcp_group().world_size, + dp_size_=get_dp_group().world_size, + sp_size_=sp_size, + vllm_parallel_config=parallel_config, + ) + module.moe_config.moe_parallel_config = module.moe_parallel_config + + # Update EPLB state + eplb_state = self.worker.model_runner.eplb_state + assert eplb_state is not None + model_config = self.worker.model_runner.model_config + eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + + num_physical_experts = num_local_experts * new_ep_size + num_logical_experts = eplb_model_state.logical_replica_count.shape[1] + parallel_config.eplb_config.num_redundant_experts = ( + num_physical_experts - num_logical_experts + ) + old_physical_to_logical = eplb_model_state.physical_to_logical_map + num_moe_layers = old_physical_to_logical.shape[0] + num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size + if new_dp_size > old_dp_size: + expanded_physical_to_logical = torch.full( + (num_moe_layers, num_local_experts * new_ep_size), + -1, + dtype=old_physical_to_logical.dtype, + device=old_physical_to_logical.device, + ) + expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = ( + old_physical_to_logical + ) + eplb_model_state.physical_to_logical_map = expanded_physical_to_logical + + old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1] + pad_size = num_physical_experts - old_num_physical_experts + if new_dp_size > old_dp_size: + assert pad_size > 0 + expanded_expert_load_pass = F.pad( + eplb_model_state.expert_load_pass, (0, pad_size), value=0 + ) + expanded_expert_load_window = F.pad( + eplb_model_state.expert_load_window, (0, pad_size), value=0 + ) + eplb_model_state.expert_load_pass = expanded_expert_load_pass + eplb_model_state.expert_load_window = expanded_expert_load_window + eplb_state.num_valid_physical_experts = old_num_physical_experts + else: + assert pad_size < 0 + eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[ + :, :num_physical_experts + ] + eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[ + :, :, :num_physical_experts + ] + eplb_state.num_valid_physical_experts = num_physical_experts + + model = self.worker.model_runner.get_model() + model.expert_weights = [] + with set_current_vllm_config(self.worker.vllm_config): + model.set_eplb_state( + eplb_model_state.expert_load_pass, + eplb_model_state.logical_to_physical_map, + eplb_model_state.logical_replica_count, + ) + model.update_physical_experts_metadata( + num_physical_experts=num_physical_experts, + num_local_physical_experts=num_local_experts, + ) + # Force re-creation of the modular kernel (and all2all manager) + # for the new EP size by resetting quant_method to base + for module in moe_modules: + if hasattr(module.quant_method, "old_quant_method"): + module.quant_method = module.quant_method.old_quant_method + module.runner = module._init_runner() + prepare_communication_buffer_for_model(self.worker.model_runner.model) + if ( + self.worker.vllm_config.compilation_config.mode + == CompilationMode.STOCK_TORCH_COMPILE + ): + # NOTE(yongji): when using stock torch.compile, + # torch.compile is triggered during GPUModelRunner's load_model() + # TODO(yongji):check do we need to re-trigger torch.compile here? + # any changes to the tensor shapes in execution should already + # be handled internally by torch.compile. 
+ backend = self.worker.vllm_config.compilation_config.init_backend( + self.worker.vllm_config + ) + compilation_counter.stock_torch_compile_count += 1 + self.worker.model_runner.model.compile(fullgraph=True, backend=backend) + + # release all previously captured CUDA graphs + if isinstance(self.worker.model_runner.model, CUDAGraphWrapper): + wrapper = self.worker.model_runner.model + wrapper.concrete_cudagraph_entries = {} + elif isinstance(self.worker.model_runner.model, UBatchWrapper): + raise RuntimeError("DBO is not yet supported in elastic EP") + + multi_block_table = self.worker.model_runner.input_batch.block_table + saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = [] + for bt in multi_block_table.block_tables: + saved_block_tables.append( + (bt.block_table.gpu.clone(), bt.block_table.cpu.clone()) + ) + multi_block_table.clear() + + # reset the compile wrapper + torch.compiler.reset() + with set_current_vllm_config(self.worker.vllm_config): + reset_compile_wrapper(self.worker.model_runner.get_model()) + + gc.collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() + unlock_workspace() + self.worker.compile_or_warm_up_model() + lock_workspace() + + for bt, (saved_gpu, saved_cpu) in zip( + multi_block_table.block_tables, saved_block_tables + ): + bt.block_table.gpu.copy_(saved_gpu) + bt.block_table.cpu.copy_(saved_cpu) + + def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None: + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding...") + + eplb_state = self.worker.model_runner.eplb_state + assert eplb_state is not None + + model_config = self.worker.model_runner.model_config + eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + is_async_enabled = eplb_state.is_async + eplb_state.is_async = False + if new_dp_size is None: + eplb_state.rearrange() + else: + # scale down + parallel_config = self.worker.vllm_config.parallel_config + tp_size = parallel_config.tensor_parallel_size + old_ep_size = parallel_config.data_parallel_size * tp_size + new_ep_size = new_dp_size * tp_size + + rank_mapping = { + old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 + for old_ep_rank in range(old_ep_size) + } + + eplb_state.rearrange(rank_mapping=rank_mapping) + # NOTE(yongji): check whether we need to synchronize here + torch.cuda.synchronize() + # reset expert_rearrangement_step to ensure all ranks are synchronized + eplb_state.expert_rearrangement_step = 0 + eplb_state.num_valid_physical_experts = ( + eplb_model_state.physical_to_logical_map.shape[1] + ) + eplb_state.is_async = is_async_enabled + self.worker.model_runner.eep_eplb_suppressed = False + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed") + + def receive_weights(self) -> None: + dp_group = get_dp_group() + assert isinstance(dp_group, StatelessGroupCoordinator) + new_dp_size = dp_group.world_size + dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank + + # Receive old_dp_size broadcasted during transfer_weights + old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu") + old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0) + old_dp_size = int(old_dp_size_tensor[0].item()) + + # Calculate which existing worker will send to this new worker + num_new_workers = new_dp_size - old_dp_size + new_worker_idx = dp_rank - old_dp_size + num_dst_per_sender = num_new_workers // old_dp_size + remainder = num_new_workers % old_dp_size + + if new_worker_idx < remainder 
* (num_dst_per_sender + 1): + sender_rank = new_worker_idx // (num_dst_per_sender + 1) + else: + sender_rank = ( + remainder + + (new_worker_idx - remainder * (num_dst_per_sender + 1)) + // num_dst_per_sender + ) + + model = self.worker.model_runner.get_model() + batch_transfer_weights( + model=model, + is_sender=False, + peer_rank=sender_rank, + dp_group=dp_group, + expert_weights=model.expert_weights, + ) + torch.cuda.synchronize() + + def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]: + dp_group = get_dp_group() + assert isinstance(dp_group, StatelessGroupCoordinator) + physical_to_logical, num_local_physical_experts, num_logical_experts = ( + broadcast_expert_mapping( + physical_to_logical=None, + num_local_physical_experts=None, + num_logical_experts=None, + dp_group=dp_group, + src_rank=0, + device=self.worker.device, + ) + ) + num_moe_layers = physical_to_logical.shape[0] + new_dp_size = get_dp_group().world_size + tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size + new_ep_size = new_dp_size * tp_size + expanded_physical_to_logical = torch.full( + (num_moe_layers, num_local_physical_experts * new_ep_size), + -1, + dtype=physical_to_logical.dtype, + device=physical_to_logical.device, + ) + old_num_physical_experts = physical_to_logical.shape[1] + expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical + return ( + expanded_physical_to_logical, + num_logical_experts, + old_num_physical_experts, + ) + + def prepare_new_worker(self) -> None: + with set_current_vllm_config(self.worker.vllm_config): + prepare_communication_buffer_for_model(self.worker.model_runner.get_model()) diff --git a/vllm/distributed/elastic_ep/elastic_state.py b/vllm/distributed/elastic_ep/elastic_state.py new file mode 100644 index 000000000..4845a16f1 --- /dev/null +++ b/vllm/distributed/elastic_ep/elastic_state.py @@ -0,0 +1,563 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import enum +import time +import weakref +from datetime import timedelta +from typing import TYPE_CHECKING, Literal + +import torch.distributed + +from vllm.config import ParallelConfig +from vllm.distributed import ( + sched_yield, + stateless_destroy_torch_distributed_process_group, +) +from vllm.logger import init_logger +from vllm.v1.engine import ( + EEPNotificationType, + ReconfigureDistributedRequest, + ReconfigureRankType, +) +from vllm.v1.engine.core import DPEngineCoreProc + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.v1.executor.abstract import Executor + +logger = init_logger(__name__) + +WorkerType = Literal["existing", "new", "removing"] + + +class ScaleUpExistingEngineState(enum.IntEnum): + WAIT_NEW_CORE_ENGINES_INIT = 0 + CREATE_STANDBY_GROUPS = 1 + TRANSFER_EXPERT_MAPPING = 2 + WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT = 3 + TRANSFER_WEIGHTS = 4 + SYNC_KV_CACHE_MEMORY_SIZE = 5 + SWITCH_AND_PREPARE = 6 + EPLB_RESHUFFLE = 7 + COMPLETE = 8 + + +class ScaleUpNewEngineState(enum.IntEnum): + PREPARE = 0 + EPLB_RESHUFFLE = 1 + COMPLETE = 2 + + +class ScaleDownRemainingEngineState(enum.IntEnum): + PREPARE = 0 + EPLB_RESHUFFLE = 1 + SWITCH_AND_PREPARE = 2 + COMPLETE = 3 + + +class ScaleDownRemovingEngineState(enum.IntEnum): + PREPARE = 0 + EPLB_RESHUFFLE = 1 + COMPLETE = 2 + + +class _BarrierTimeoutError(RuntimeError): + """ + Exception raised for timeout + in the first stage of our two-staged + TCPStore based barrier to synchronize the + execution of all engines in the DP group. 
+ """ + + +class ElasticEPScalingState: + def __init__( + self, + model_executor: "Executor", + engine_core: "DPEngineCoreProc", + vllm_config: "VllmConfig", + new_parallel_config: ParallelConfig, + worker_type: WorkerType, + scale_type: Literal["scale_up", "scale_down"], + reconfig_request: ReconfigureDistributedRequest | None = None, + ): + self.model_executor_ref = weakref.ref(model_executor) + self.engine_core_ref = weakref.ref(engine_core) + self.vllm_config = vllm_config + self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None + self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None + self.new_parallel_config: ParallelConfig = new_parallel_config + self.new_dp_group: torch.distributed.ProcessGroup | None = ( + self.engine_core.dp_group if worker_type == "new" else None + ) + self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None + self.worker_type = worker_type + self.scale_type = scale_type + self.reconfig_request = reconfig_request + + if scale_type == "scale_up": + self.state = ( + ScaleUpNewEngineState.PREPARE + if worker_type == "new" + else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT + ) + else: + self.state = ( + ScaleDownRemovingEngineState.PREPARE + if worker_type == "removing" + else ScaleDownRemainingEngineState.PREPARE + ) + + @property + def model_executor(self) -> "Executor": + model_executor = self.model_executor_ref() + if model_executor is None: + raise RuntimeError("Model executor has been garbage collected") + return model_executor + + @property + def engine_core(self) -> "DPEngineCoreProc": + engine_core = self.engine_core_ref() + if engine_core is None: + raise RuntimeError("Engine core has been garbage collected") + return engine_core + + def progress(self) -> bool: + if self.scale_type == "scale_up": + return ( + self._progress_new_engine() + if self.worker_type == "new" + else self._progress_existing_engine() + ) + return ( + self._progress_removing_engine() + if self.worker_type == "removing" + else self._progress_remaining_engine() + ) + + def _execute_tcp_store_barrier( + self, dp_store, group_rank, group_size, barrier_id, timeout=None + ): + arrival_key = f"arrival_{barrier_id}_{group_rank}" + dp_store.set(arrival_key, b"1") + + start_time = time.time() + processes_arrived: set[int] = set() + + while len(processes_arrived) < group_size: + if ( + timeout is not None + and time.time() - start_time > timeout.total_seconds() + ): + raise _BarrierTimeoutError( + f"Barrier timed out after {timeout.total_seconds()} seconds" + ) + + for i in range(group_size): + if i in processes_arrived: + continue + + key = f"arrival_{barrier_id}_{i}" + present = dp_store.check([key]) + if present: + processes_arrived.add(i) + + if len(processes_arrived) < group_size: + sched_yield() + + def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool: + """ + Execute a two-staged barrier to synchronize all engines in the DP group. + + Some DP EngineCores may receive the reconfiguration notifications + later than others, and already proceed to engine step (model forward) + in the busy loop. + In this case, EngineCores that already proceed to reconfiguration + should skip reconfiguration and execute model forward for one more + step, so in the next step, all EngineCores will be synchronized. + We use a two-staged barrier to achieve this. 
The first time each + EngineCore executes the barrier, if a timeout is reached before the + barrier completes, that means some EngineCores have already entered + engine step. The EngineCores that timed out will then proceed to + engine step, and will synchronize with the other EngineCores in the + next step with a barrier without timeout. + """ + dp_store = self.new_dp_store if use_new_group else self.old_dp_store + dp_group = self.new_dp_group if use_new_group else self.old_dp_group + assert dp_group is not None + + group_rank = dp_group.rank() + group_size = dp_group.size() + barrier_id = f"eep_barrier_{barrier_name}" + sync_key = f"{barrier_id}_sync" + + # TODO(yongji): figure out appropriate timeout for the barrier + timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5) + + try: + self._execute_tcp_store_barrier( + dp_store, group_rank, group_size, barrier_id, timeout=timeout + ) + torch.distributed.barrier(dp_group) + if group_rank == 0: + dp_store.delete_key(sync_key) + for i in range(group_size): + dp_store.delete_key(f"arrival_{barrier_id}_{i}") + return True + except _BarrierTimeoutError as e: + if timeout is None: + raise RuntimeError("Unexpected timeout encountered") from e + dp_store.compare_set(sync_key, "", b"1") + return False + + def _progress_existing_engine(self) -> bool: + state = self.state + + if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT: + return False + + elif state == ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS: + # NOTE(yongji): wait for all existing workers to receive the request + if ( + int(self.old_dp_store.get("eep_barrier_engine_count")) + < self.old_dp_group.size() + ): + return False + if not self._staged_barrier( + use_new_group=False, barrier_name="create_standby_groups" + ): + return False + if self.old_dp_group.rank() == 0: + self.old_dp_store.delete_key("eep_barrier_engine_count") + self._create_standby_groups() + self.state = ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING + return True + + elif state == ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING: + self._transfer_expert_mapping() + self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT + return True + + elif state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT: + return False + + elif state == ScaleUpExistingEngineState.TRANSFER_WEIGHTS: + if ( + int(self.old_dp_store.get("eep_barrier_engine_count")) + < self.old_dp_group.size() + ): + return False + if not self._staged_barrier( + use_new_group=False, barrier_name="transfer_weights" + ): + return False + if self.old_dp_group.rank() == 0: + self.old_dp_store.delete_key("eep_barrier_engine_count") + self._transfer_weights() + self.state = ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE + return True + + elif state == ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE: + self._sync_kv_cache_memory_size() + self.state = ScaleUpExistingEngineState.SWITCH_AND_PREPARE + return True + + elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE: + self._switch_and_prepare() + self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE + self.new_dp_store.add("eep_barrier_engine_count", 1) + return True + + elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE: + assert self.new_dp_group is not None + if ( + int(self.new_dp_store.get("eep_barrier_engine_count")) + < self.new_dp_group.size() + ): + return False + if not self._staged_barrier( + use_new_group=True, barrier_name="eplb_reshuffle" + ): + return False + if self.new_dp_group.rank() == 0: + 
self.new_dp_store.delete_key("eep_barrier_engine_count") + self._eplb_reshuffle() + self.state = ScaleUpExistingEngineState.COMPLETE + self._update_parallel_config() + return True + + else: + assert self.state == ScaleUpExistingEngineState.COMPLETE + return True + + def _progress_new_engine(self) -> bool: + state = self.state + assert self.new_dp_group is not None + + if state == ScaleUpNewEngineState.PREPARE: + tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu") + torch.distributed.all_reduce( + tensor, + op=torch.distributed.ReduceOp.MAX, + group=self.new_dp_group, + ) + data = tensor.tolist() + self.engine_core.engines_running = bool(data[0]) + self.engine_core.current_wave = int(data[1]) + self.engine_core.step_counter = int(data[2]) + self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE + self.new_dp_store.add("eep_barrier_engine_count", 1) + return True + + elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE: + if ( + int(self.new_dp_store.get("eep_barrier_engine_count")) + < self.new_dp_group.size() + ): + return False + if not self._staged_barrier( + use_new_group=True, barrier_name="eplb_reshuffle" + ): + return False + assert self.new_dp_group.rank() > 0 + self._eplb_reshuffle() + self.state = ScaleUpNewEngineState.COMPLETE + return True + + else: + assert self.state == ScaleUpNewEngineState.COMPLETE + return True + + def _progress_remaining_engine(self) -> bool: + state = self.state + + if state == ScaleDownRemainingEngineState.PREPARE: + self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE + self.old_dp_store.add("eep_barrier_engine_count", 1) + return True + + elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE: + if ( + int(self.old_dp_store.get("eep_barrier_engine_count")) + < self.old_dp_group.size() + ): + return False + if not self._staged_barrier( + use_new_group=False, barrier_name="eplb_reshuffle" + ): + return False + if self.old_dp_group.rank() == 0: + self.old_dp_store.delete_key("eep_barrier_engine_count") + self._eplb_reshuffle_before_scale_down() + self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE + # NOTE(yongji): currently, after EPLB reshuffle + # that redistributes experts to remaining workers, workers + # to be removed will immediately initiate shutdown; + # existing workers can no longer execute forward steps using + # the old setup. In the future, we may keep + # the removing workers alive a bit longer, + # e.g., to drain in-batch requests. 
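+ # NOTE: for scale-down, the standby groups created here already target
+ # the reduced DP size from the reconfig request, so the remaining
+ # engines create them, switch over, and update their parallel config
+ # within this same progress step.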
+ self._create_standby_groups() + self._switch_and_prepare() + self._update_parallel_config() + self.state = ScaleDownRemainingEngineState.COMPLETE + return True + + else: + assert self.state == ScaleDownRemainingEngineState.COMPLETE + return True + + def _progress_removing_engine(self) -> bool: + state = self.state + + if state == ScaleDownRemovingEngineState.PREPARE: + self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE + self.old_dp_store.add("eep_barrier_engine_count", 1) + return True + + if state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE: + if ( + int(self.old_dp_store.get("eep_barrier_engine_count")) + < self.old_dp_group.size() + ): + return False + if not self._staged_barrier( + use_new_group=False, barrier_name="eplb_reshuffle" + ): + return False + assert self.old_dp_group.rank() > 0 + self._eplb_reshuffle_before_scale_down() + self._switch_and_remove() + self.state = ScaleDownRemovingEngineState.COMPLETE + self.engine_core._eep_send_engine_core_notification( + EEPNotificationType.SHUTDOWN_COMPLETE + ) + self.engine_core.shutdown() + return True + + else: + assert self.state == ScaleDownRemovingEngineState.COMPLETE + return True + + def handle_notification(self, notification_type: EEPNotificationType): + assert self.worker_type != "new" + if ( + notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY + and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT + ): + self.old_dp_store.add("eep_barrier_engine_count", 1) + self.state = ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS + elif ( + notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY + and self.state + == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT + ): + self.old_dp_store.add("eep_barrier_engine_count", 1) + self.state = ScaleUpExistingEngineState.TRANSFER_WEIGHTS + + def is_complete(self) -> bool: + if self.scale_type == "scale_up": + return ( + self.state == ScaleUpNewEngineState.COMPLETE + if self.worker_type == "new" + else self.state == ScaleUpExistingEngineState.COMPLETE + ) + return ( + self.state == ScaleDownRemovingEngineState.COMPLETE + if self.worker_type == "removing" + else self.state == ScaleDownRemainingEngineState.COMPLETE + ) + + def _create_standby_groups(self): + self.new_dp_group, self.new_dp_store = ( + self.new_parallel_config.stateless_init_dp_group(return_store=True) + ) + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("create_standby_groups", self.reconfig_request) + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] Created standby communication groups") + + def _transfer_weights(self): + assert self.reconfig_request is not None + old_dp_size = self.old_dp_group.size() + new_dp_size = self.reconfig_request.new_data_parallel_size + + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size) + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] Transferred weights to new workers") + + def _transfer_expert_mapping(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("broadcast_expert_mapping",) + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] Broadcasted expert mapping to new workers") + + def _sync_kv_cache_memory_size(self): + assert self.engine_core.available_gpu_memory_for_kv_cache > 0 + assert self.new_dp_group is not None + ParallelConfig.sync_kv_cache_memory_size( + self.new_dp_group, + self.engine_core.available_gpu_memory_for_kv_cache, + ) + if self.old_dp_group.rank() 
== 0: + logger.info("[Elastic EP] Synced KV cache memory size to new workers") + + def _switch_and_prepare(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("switch_and_prepare",) + ) + old_dp_group = self.old_dp_group + stateless_destroy_torch_distributed_process_group(old_dp_group) + assert self.new_dp_group is not None + new_dp_group = self.new_dp_group + self.engine_core.dp_group = new_dp_group + self.engine_core.dp_rank = new_dp_group.rank() + self.engine_core.dp_store = self.new_dp_store + engines_running = int(self.engine_core.engines_running) + current_wave = self.engine_core.current_wave + step_counter = self.engine_core.step_counter + tensor = torch.tensor( + [engines_running, current_wave, step_counter], + dtype=torch.int32, + device="cpu", + ) + torch.distributed.all_reduce( + tensor, op=torch.distributed.ReduceOp.MAX, group=new_dp_group + ) + data = tensor.tolist() + self.engine_core.engines_running = bool(data[0]) + self.engine_core.current_wave = int(data[1]) + self.engine_core.step_counter = int(data[2]) + if new_dp_group.rank() == 0: + self.engine_core._eep_send_engine_core_notification( + EEPNotificationType.RECONFIGURE_FINISHED + ) + logger.info("[Elastic EP] Switched to new setup") + + def _eplb_reshuffle(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("perform_eplb_reshuffle",) + ) + assert self.new_dp_group is not None + if self.new_dp_group.rank() == 0: + logger.info("[Elastic EP] EPLB reshuffle completed") + + def _eplb_reshuffle_before_scale_down(self): + assert self.reconfig_request is not None + self.model_executor.collective_rpc( + "elastic_ep_execute", + args=( + "perform_eplb_reshuffle", + self.reconfig_request.new_data_parallel_size, + ), + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] EPLB reshuffle completed") + + def _switch_and_remove(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("switch_and_remove",) + ) + + def _update_parallel_config(self): + assert self.reconfig_request is not None + reconfig_request = self.reconfig_request + parallel_config = self.vllm_config.parallel_config + parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size + if ( + reconfig_request.new_data_parallel_rank + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if ( + reconfig_request.new_data_parallel_rank_local + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank_local = ( + reconfig_request.new_data_parallel_rank_local + ) + parallel_config.data_parallel_master_ip = ( + reconfig_request.new_data_parallel_master_ip + ) + parallel_config.data_parallel_master_port = ( + reconfig_request.new_data_parallel_master_port + ) + parallel_config._data_parallel_master_port_list = ( + reconfig_request.new_data_parallel_master_port_list + ) + parallel_config._stateless_world_group_port_list = ( + reconfig_request.new_stateless_world_group_port_list + ) + parallel_config._stateless_dp_group_port_list = ( + reconfig_request.new_stateless_dp_group_port_list + ) + parallel_config._stateless_ep_group_port_list = ( + reconfig_request.new_stateless_ep_group_port_list + ) + parallel_config._stateless_eplb_group_port_list = ( + reconfig_request.new_stateless_eplb_group_port_list + ) diff --git a/vllm/distributed/elastic_ep/standby_state.py b/vllm/distributed/elastic_ep/standby_state.py new file mode 100644 index 000000000..d11e0b550 --- /dev/null +++ 
b/vllm/distributed/elastic_ep/standby_state.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.distributed.parallel_state import ( + _init_stateless_group, + _node_count, + get_pp_group, + get_tp_group, + get_world_group, +) +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + +_STANDBY_WORLD: StatelessGroupCoordinator | None = None +_STANDBY_WORLD_NODE_COUNT: int | None = None +_STANDBY_DP: StatelessGroupCoordinator | None = None +_STANDBY_EP: StatelessGroupCoordinator | None = None +_STANDBY_EPLB: StatelessGroupCoordinator | None = None + + +def get_standby_dp_group() -> StatelessGroupCoordinator | None: + return _STANDBY_DP + + +def get_standby_ep_group() -> StatelessGroupCoordinator | None: + return _STANDBY_EP + + +def get_standby_eplb_group() -> StatelessGroupCoordinator | None: + return _STANDBY_EPLB + + +def get_standby_world_group() -> StatelessGroupCoordinator | None: + return _STANDBY_WORLD + + +def create_standby_groups( + new_dp_size: int, + new_world_size_across_dp: int, + master_ip: str, + world_group_ports: list[list[int]], + dp_group_ports: list[list[int]], + ep_group_ports: list[list[int]], + eplb_group_ports: list[list[int]] | None = None, + backend: str | None = None, +) -> None: + global \ + _STANDBY_WORLD, \ + _STANDBY_WORLD_NODE_COUNT, \ + _STANDBY_DP, \ + _STANDBY_EP, \ + _STANDBY_EPLB + + assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size + world_group = get_world_group() + assert isinstance(world_group, StatelessGroupCoordinator) + backend = backend or world_group.backend + + standby_world_ranks = [list(range(new_world_size_across_dp))] + _STANDBY_WORLD = _init_stateless_group( + standby_world_ranks, + "world", + world_group_ports, + master_ip, + backend, + use_device_communicator=False, + ) + _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group) + + tp_size = get_tp_group().world_size + pp_size = get_pp_group().world_size + + all_ranks = torch.arange(new_world_size_across_dp).reshape( + -1, new_dp_size, pp_size, tp_size + ) + standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0) + standby_dp_ranks = [x.tolist() for x in standby_dp_ranks] + _STANDBY_DP = _init_stateless_group( + standby_dp_ranks, "dp", dp_group_ports, master_ip, backend + ) + + standby_ep_ranks = ( + all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0) + ) + standby_ep_ranks = [x.tolist() for x in standby_ep_ranks] + _STANDBY_EP = _init_stateless_group( + standby_ep_ranks, "ep", ep_group_ports, master_ip, backend + ) + + if eplb_group_ports is not None: + _STANDBY_EPLB = _init_stateless_group( + standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend + ) + + +def pop_standby_groups() -> dict: + """Return all standby groups and clear the standby state.""" + global \ + _STANDBY_WORLD, \ + _STANDBY_WORLD_NODE_COUNT, \ + _STANDBY_DP, \ + _STANDBY_EP, \ + _STANDBY_EPLB + + result = dict( + world=_STANDBY_WORLD, + dp=_STANDBY_DP, + ep=_STANDBY_EP, + eplb=_STANDBY_EPLB, + node_count=_STANDBY_WORLD_NODE_COUNT, + ) + _STANDBY_WORLD = None + _STANDBY_WORLD_NODE_COUNT = None + _STANDBY_DP = None + _STANDBY_EP = None + _STANDBY_EPLB = None + return result diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index b81c7fa9c..5dd862f36 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -24,7 +24,6 
@@ logger = init_logger(__name__) def start_async_worker( state: "EplbState", - rank_mapping: dict[int, int] | None = None, is_profile: bool = False, ) -> threading.Thread: eplb_group = get_eplb_group().device_group @@ -45,7 +44,6 @@ def start_async_worker( eplb_group=eplb_group, cuda_stream=cuda_stream, is_profile=is_profile, - rank_mapping=rank_mapping, ) ) except Exception as exc: # pragma: no cover - diagnostic path @@ -107,7 +105,6 @@ async def transfer_run_periodically( eplb_group: ProcessGroup, cuda_stream: torch.cuda.Stream, is_profile: bool = False, - rank_mapping: dict[int, int] | None = None, ) -> None: while True: await asyncio.to_thread(state.rearrange_event.wait) @@ -176,7 +173,6 @@ async def transfer_run_periodically( ep_group=eplb_group, is_profile=is_profile, cuda_stream=cuda_stream, - rank_mapping=rank_mapping, ) event = torch.cuda.Event(blocking=False) cuda_stream.record_event(event) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 891f19cfe..b417c2b32 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -40,6 +40,7 @@ from vllm.distributed.parallel_state import ( get_node_count, in_the_same_node_as, ) +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts @@ -302,6 +303,14 @@ class EplbState: """ CUDA device index for the async EPLB worker thread. """ + self.num_valid_physical_experts: int = 0 + """ + Number of valid physical experts. + This is the number of physical experts that are + actually mapped to logical experts. In elastic EP, + newly started EP ranks may not have physical experts + mapped yet. + """ if self.device.type == "cuda": self.cuda_device_index = self.device.index if self.cuda_device_index is None and torch.cuda.is_available(): @@ -367,9 +376,6 @@ class EplbState: self, model: MixtureOfExperts, model_config: ModelConfig, - global_expert_load: torch.Tensor | None = None, - old_global_expert_indices: torch.Tensor | None = None, - rank_mapping: dict[int, int] | None = None, ): """ Build the initial EPLB state. @@ -462,75 +468,15 @@ class EplbState: ) self.expert_rearrangement_step_interval = eplb_step_interval - # Set the policy based on the selected eplb algorithm type. 
policy_type = self.parallel_config.eplb_config.policy self.policy = EPLB_POLICIES[policy_type] logger.debug("Selected EPLB policy: %s", policy_type) - if global_expert_load is not None: - ep_group = get_ep_group().device_group - assert global_expert_load.shape == ( - model.num_moe_layers, - model.num_logical_experts, - ) - assert global_expert_load.dtype == torch.int64 - num_replicas = model.num_physical_experts - num_groups = model.num_expert_groups - num_nodes = get_node_count() - num_gpus = ep_group.size() - - if num_gpus % num_nodes != 0: - num_nodes = 1 - logger.warning_once( - f"num_gpus % num_nodes != 0, " - "not using hierarchical rearrangement algorithm.\n" - f"{num_gpus=}, {num_nodes=}" - ) - - # Get new expert mappings - ( - new_physical_to_logical_map, - new_logical_to_physical_map, - new_logical_replica_count, - ) = self.policy.rebalance_experts( - global_expert_load, - num_replicas, - num_groups, - num_nodes, - num_gpus, - ) - - max_physical_slots = new_logical_to_physical_map.shape[-1] - assert max_physical_slots <= logical_to_physical_map.shape[-1] - new_logical_to_physical_map = torch.nn.functional.pad( - new_logical_to_physical_map, - (0, logical_to_physical_map.shape[-1] - max_physical_slots), - value=-1, - ) - physical_to_logical_map = new_physical_to_logical_map.to(self.device) - logical_to_physical_map.copy_(new_logical_to_physical_map) - logical_replica_count.copy_(new_logical_replica_count) - else: - new_physical_to_logical_map = None - - new_logical_to_physical_map = None - - new_logical_replica_count = None model.set_eplb_state( expert_load_pass, logical_to_physical_map, logical_replica_count, ) - if global_expert_load is not None: - rearrange_expert_weights_inplace( - old_global_expert_indices, - new_physical_to_logical_map, - model.expert_weights, - ep_group, - False, - rank_mapping, - ) - self.expert_rearrangement_step = 0 expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]] @@ -561,11 +507,12 @@ class EplbState: recv_dst_rows=np.array([]), ), cuda_device_index=self.cuda_device_index, - new_physical_to_logical_map=new_physical_to_logical_map, - new_logical_to_physical_map=new_logical_to_physical_map, - new_logical_replica_count=new_logical_replica_count, + new_physical_to_logical_map=None, + new_logical_to_physical_map=None, + new_logical_replica_count=None, ) self.model_states[model_config.compute_hash()] = model_state + self.num_valid_physical_experts = model.num_physical_experts def step( self, @@ -696,8 +643,6 @@ class EplbState: def rearrange( self, is_profile: bool = False, - execute_shuffle: bool = True, - global_expert_loads: list[torch.Tensor] | None = None, rank_mapping: dict[int, int] | None = None, ) -> torch.Tensor | None: """ @@ -707,12 +652,6 @@ class EplbState: is_profile (bool): If `True`, perform a dummy rearrangement. This is used in `profile_run` to reserve enough memory, no memory movement will be performed. Default is False. - execute_shuffle (bool): If `True`, execute the shuffle - in elastic expert parallel (EEP). Default is True. - global_expert_loads (list[torch.Tensor] | None): The global expert - loads when scaling is done in EEP. - List of expert loads for the main and drafter - (when spec decode is used) models. rank_mapping (dict[int, int] | None): The rank mapping when scaling is done in EEP. 
""" @@ -734,67 +673,34 @@ class EplbState: "(profile)" if is_profile else "", ) - if global_expert_loads is None: - # Map the physical expert load to global logical experts - global_expert_load_windows = [] - if not execute_shuffle: - num_models = torch.tensor( - [len(self.model_states)], dtype=torch.int32, device="cpu" - ) - torch.distributed.broadcast( - num_models, group=get_ep_group().cpu_group, group_src=0 - ) - - for eplb_model_state in self.model_states.values(): - logical_expert_load_window = torch.zeros( - self.expert_load_window_size, - eplb_model_state.model.num_moe_layers, - eplb_model_state.model.num_logical_experts, - dtype=eplb_model_state.expert_load_window.dtype, - device=eplb_model_state.expert_load_window.device, - ) - logical_expert_load_window.scatter_add_( - dim=-1, - index=eplb_model_state.physical_to_logical_map.unsqueeze(0) - .expand_as(eplb_model_state.expert_load_window) - .long(), - src=eplb_model_state.expert_load_window, - ) - - if not execute_shuffle: - metadata = torch.tensor( - [ - eplb_model_state.model.num_moe_layers, - eplb_model_state.model.num_logical_experts, - eplb_model_state.physical_to_logical_map.shape[1], - ], - dtype=torch.int32, - device="cpu", - ) - torch.distributed.broadcast( - metadata, group=get_ep_group().cpu_group, group_src=0 - ) - - global_expert_load_window = logical_expert_load_window.sum(dim=0) - global_expert_load_windows.append(global_expert_load_window) - # Perform all-reduce to get the expert load across all ranks for each model - global_expert_load_windows = self._allreduce_list( - global_expert_load_windows + # Map the physical expert load to global logical experts + global_expert_load_windows = [] + for eplb_model_state in self.model_states.values(): + expert_load_window = eplb_model_state.expert_load_window[ + :, :, : self.num_valid_physical_experts + ] + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + eplb_model_state.model.num_moe_layers, + eplb_model_state.model.num_logical_experts, + dtype=eplb_model_state.expert_load_window.dtype, + device=eplb_model_state.expert_load_window.device, ) - if not execute_shuffle: - for eplb_model_state, global_expert_load_window in zip( - self.model_states.values(), global_expert_load_windows - ): - # (num_moe_layers, old_num_physical_experts) - old_global_expert_indices = eplb_model_state.physical_to_logical_map - torch.distributed.broadcast( - old_global_expert_indices, group=ep_group, group_src=0 - ) - if not execute_shuffle: - return global_expert_load_windows - else: - assert execute_shuffle - global_expert_load_windows = global_expert_loads + logical_expert_load_window.scatter_add_( + dim=-1, + index=eplb_model_state.physical_to_logical_map[ + :, : self.num_valid_physical_experts + ] + .unsqueeze(0) + .expand_as(expert_load_window) + .long(), + src=expert_load_window, + ) + + global_expert_load_window = logical_expert_load_window.sum(dim=0) + global_expert_load_windows.append(global_expert_load_window) + # Perform all-reduce to get the expert load across all ranks for each model + global_expert_load_windows = self._allreduce_list(global_expert_load_windows) # TODO(bowen): Treat differently for prefill and decode nodes eplb_model_state = next(iter(self.model_states.values())) @@ -806,8 +712,10 @@ class EplbState: # NOTE(yongji): scale down, we need to rebalance the experts on # remaining GPUs, transfer the experts while we haven't shutdown # the GPUs to be released. 
- cpu_group = get_ep_group().cpu_group - num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping) + coordinator = get_ep_group() + assert isinstance(coordinator, StatelessGroupCoordinator) + tcp_store_group = coordinator.tcp_store_group + num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping) num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values()) num_replicas = ( num_replicas // ep_group.size() * num_gpus @@ -933,7 +841,6 @@ class EplbState: if self.async_worker is None: self.async_worker = start_async_worker( self, - rank_mapping=rank_mapping, is_profile=is_profile, ) @@ -1089,83 +996,6 @@ class EplbState: model_state.new_logical_to_physical_map = None model_state.new_logical_replica_count = None - @staticmethod - def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]: - """ - Receive the expert load and old placement from the master rank. - """ - ep_group = get_ep_group() - num_models = torch.empty(1, dtype=torch.int32, device="cpu") - torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0) - num_models = num_models.item() - global_expert_loads = [] - old_global_expert_indices_per_model = [] - for _ in range(num_models): - metadata = torch.empty(3, dtype=torch.int32, device="cpu") - torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0) - num_moe_layers, num_logical_experts, num_old_physical_experts = ( - metadata.tolist() - ) - global_expert_load = torch.zeros( - (num_moe_layers, num_logical_experts), - dtype=torch.int64, - device=ep_group.device, - ) - all_reduce(global_expert_load, group=ep_group.device_group) - old_global_expert_indices = torch.empty( - (num_moe_layers, num_old_physical_experts), - dtype=torch.int64, - device=ep_group.device, - ) - torch.distributed.broadcast( - old_global_expert_indices, - group=ep_group.device_group, - group_src=0, - ) - global_expert_loads.append(global_expert_load) - old_global_expert_indices_per_model.append(old_global_expert_indices) - return global_expert_loads, old_global_expert_indices_per_model - - @classmethod - def get_eep_state( - cls, parallel_config: ParallelConfig - ) -> tuple[ - list[torch.Tensor] | None, - list[torch.Tensor] | None, - dict[int, int] | None, - ]: - num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") - torch.distributed.broadcast( - num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0, - ) - num_local_physical_experts = int(num_local_physical_experts.item()) - new_ep_size = get_ep_group().world_size - global_expert_loads, old_global_expert_indices_per_model = ( - EplbState.recv_state() - ) - - # EP configuration for all models has to be the same so as eplb config - num_logical_experts = global_expert_loads[0].shape[1] - parallel_config.eplb_config.num_redundant_experts = ( - num_local_physical_experts * new_ep_size - num_logical_experts - ) - assert ( - old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts - == 0 - ) - old_ep_size = ( - old_global_expert_indices_per_model[0].shape[1] - // num_local_physical_experts - ) - rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)} - return ( - global_expert_loads, - old_global_expert_indices_per_model, - rank_mapping, - ) - def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]: """ All-reduce a list of tensors. 
@@ -1203,6 +1033,60 @@ class EplbState: load_pass_list.append(eplb_model_state.expert_load_pass.clone()) return self._allreduce_list(load_pass_list) + @classmethod + def from_mapping( + cls, + model: MixtureOfExperts, + model_config: ModelConfig, + device: torch.device, + parallel_config: ParallelConfig, + expanded_physical_to_logical: torch.Tensor, + num_valid_physical_experts: int, + ) -> "EplbState": + eplb_state = cls( + parallel_config=parallel_config, + device=device, + ) + eplb_state.add_model( + model=model, + model_config=model_config, + ) + eplb_state.num_valid_physical_experts = num_valid_physical_experts + num_moe_layers = expanded_physical_to_logical.shape[0] + num_physical_experts = expanded_physical_to_logical.shape[1] + eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical) + + logical_to_physical_map = torch.full( + ( + num_moe_layers, + model.num_logical_experts, + eplb_model_state.logical_to_physical_map.shape[2], + ), + -1, + dtype=torch.int64, + ) + logical_replica_count = torch.zeros( + (num_moe_layers, model.num_logical_experts), + dtype=torch.int64, + ) + expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy() + for layer_idx in range(num_moe_layers): + for phys_idx in range(num_physical_experts): + logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx] + if logical_idx >= 0: + replica_idx = logical_replica_count[layer_idx, logical_idx] + logical_to_physical_map[layer_idx, logical_idx, replica_idx] = ( + phys_idx + ) + logical_replica_count[layer_idx, logical_idx] += 1 + + logical_to_physical_map = logical_to_physical_map.to(device) + logical_replica_count = logical_replica_count.to(device) + eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map) + eplb_model_state.logical_replica_count.copy_(logical_replica_count) + return eplb_state + @dataclass class EplbLayerState: diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 1be1e2483..777f9c553 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -19,6 +19,8 @@ from torch.distributed import ( get_global_rank, ) +from vllm.distributed.parallel_state import get_ep_group +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator from vllm.logger import init_logger logger = init_logger(__name__) @@ -249,10 +251,18 @@ def move_to_buffer( b[dst].copy_(w[src_local], non_blocking=True) p2p_ops: list[P2POp] = [] + if isinstance(get_ep_group(), StatelessGroupCoordinator): + ep_group = get_ep_group() + is_stateless = True + else: + is_stateless = False - # Pre-compute global ranks mapping + # Pre-compute global ranks mapping (only needed for non-stateless groups) ep_size = ep_group.size() - rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)} + if not is_stateless: + rank_to_global = { + rank: get_global_rank(ep_group, rank) for rank in range(ep_size) + } # 2. 
Post sends if send_count > 0: @@ -284,15 +294,23 @@ def move_to_buffer( if recver_pos < len(ranks_to_recv): recv_ranks.append(ranks_to_recv[recver_pos]) for dst in recv_ranks: - dst_global = rank_to_global[dst] - p2p_ops += [ - P2POp( - torch.distributed.isend, - w[src], - dst_global, - ) - for w in expert_weights - ] + if is_stateless: + for w in expert_weights: + op = object.__new__(P2POp) + op.op = torch.distributed.isend + op.tensor = w[src] + op.group_peer = dst + p2p_ops.append(op) + else: + dst_global = rank_to_global[dst] + p2p_ops += [ + P2POp( + torch.distributed.isend, + w[src], + dst_global, + ) + for w in expert_weights + ] # 3. Post recvs if recv_count > 0: @@ -321,26 +339,40 @@ def move_to_buffer( src = ranks_to_send[recver_pos // num_dst_per_sender] else: src = ranks_to_send[recver_pos - remainder_start] - src_global = rank_to_global[src] - p2p_ops += [ - P2POp( - torch.distributed.irecv, - b[dst], - src_global, - ) - for b in expert_weights_buffers - ] + if is_stateless: + for b in expert_weights_buffers: + op = object.__new__(P2POp) + op.op = torch.distributed.irecv + op.tensor = b[dst] + op.group_peer = src + p2p_ops.append(op) + else: + src_global = rank_to_global[src] + p2p_ops += [ + P2POp( + torch.distributed.irecv, + b[dst], + src_global, + ) + for b in expert_weights_buffers + ] # 4. Execute the P2P operations. The real communication happens here. if p2p_ops and cuda_stream is not None: with torch.cuda.stream(cuda_stream): + if is_stateless: + ep_group.device_communicator.batch_isend_irecv(p2p_ops) + else: + reqs = batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + elif p2p_ops: + if is_stateless: + ep_group.device_communicator.batch_isend_irecv(p2p_ops) + else: reqs = batch_isend_irecv(p2p_ops) for req in reqs: req.wait() - elif p2p_ops: - reqs = batch_isend_irecv(p2p_ops) - for req in reqs: - req.wait() # wait for the communication to finish return ( is_unchanged, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 9994096bf..9e6b6df08 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from datetime import timedelta from multiprocessing import shared_memory -from typing import Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol from unittest.mock import patch import torch @@ -55,6 +55,9 @@ from vllm.utils.torch_utils import ( direct_register_custom_op, ) +if TYPE_CHECKING: + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + @dataclass class GraphCaptureContext: @@ -1157,6 +1160,55 @@ def init_model_parallel_group( ) +def _init_stateless_group( + group_ranks: list[list[int]], + group_name: str, + group_ports: list[list[int]], + host: str, + backend: str, + use_device_communicator: bool = True, +) -> "StatelessGroupCoordinator": + """Create a StatelessGroupCoordinator with the given parameters.""" + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + + world = get_world_group() + return StatelessGroupCoordinator( + group_ranks=group_ranks, + local_rank=world.local_rank, + torch_distributed_backend=backend, + use_device_communicator=use_device_communicator, + group_name=group_name, + host=host, + group_ports=group_ports, + global_rank=world.rank, + global_world_size=world.world_size, + ) + + +def _replace_active_groups( + *, + world: GroupCoordinator | None, + dp: GroupCoordinator | None, + ep: GroupCoordinator | 
None, + eplb: GroupCoordinator | None, + node_count: int | None, +) -> None: + """Destroy the current DP/EP/WORLD/EPLB groups and replace them. + + Destruction is collective — all ranks in the old groups must call this + function together. Pass all-``None`` to tear down without replacement. + """ + global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT + for group in (_DP, _EP, _WORLD, _EPLB): + if group is not None: + group.destroy() + _WORLD = world + _DP = dp + _EP = ep + _EPLB = eplb + _NODE_COUNT = node_count + + _TP: GroupCoordinator | None = None @@ -1254,6 +1306,39 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable +def _init_elastic_ep_world( + config, local_rank: int, backend: str, rank: int, world_size: int +) -> None: + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + + global _WORLD, _NODE_COUNT + assert _WORLD is None, "world group already initialized" + parallel_config = config.parallel_config + global_rank = parallel_config.data_parallel_rank * world_size + rank + global_world_size = parallel_config.world_size_across_dp + all_ranks = list(range(global_world_size)) + group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)] + if global_rank in all_ranks: + group_ranks = [all_ranks] + group_ports = [parallel_config.get_next_stateless_world_group_port()] + world = StatelessGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=False, + group_name="world", + host=parallel_config.data_parallel_master_ip, + group_ports=group_ports, + global_rank=global_rank, + global_world_size=global_world_size, + ) + assert parallel_config.nnodes_within_dp == 1, ( + "Elastic EP is not supported with multi-node TP/PP" + ) + _NODE_COUNT = _node_count(world.tcp_store_group) + _WORLD = world + + def init_distributed_environment( world_size: int = -1, rank: int = -1, @@ -1273,6 +1358,7 @@ def init_distributed_environment( from vllm.config import get_current_vllm_config_or_none config = get_current_vllm_config_or_none() + enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep if ( config is not None and config.parallel_config.distributed_executor_backend != "external_launcher" @@ -1280,6 +1366,7 @@ def init_distributed_environment( config.parallel_config.nnodes > 1 or config.parallel_config.data_parallel_size > 1 ) + and not enable_elastic_ep ): parallel_config = config.parallel_config # adjust to take into account data parallelism @@ -1333,6 +1420,18 @@ def init_distributed_environment( rank=rank, timeout=timeout, ) + if enable_elastic_ep: + tp_pp_cpu_group = torch.distributed.new_group( + backend="gloo", timeout=timeout + ) + if _node_count(tp_pp_cpu_group) > 1: + # NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip + # to initialize all DP/EP groups, hence all ranks within TP/PP group + # must reside on the same node + raise RuntimeError( + "Elastic EP is not yet supported with multi-node TP/PP" + ) + # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -1341,6 +1440,9 @@ def init_distributed_environment( # setting, where we can use rank as local rank local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank global _WORLD, _NODE_COUNT, _INNER_DP_WORLD + if enable_elastic_ep: + _init_elastic_ep_world(config, local_rank, backend, rank, world_size) + return if _WORLD is None: ranks = 
list(range(torch.distributed.get_world_size())) _WORLD = init_world_group(ranks, local_rank, backend) @@ -1404,16 +1506,33 @@ def initialize_model_parallel( """ # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() - world_size: int = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - backend = backend or torch.distributed.get_backend(get_world_group().device_group) - data_parallel_size = 1 - from vllm.config import get_current_vllm_config_or_none + from vllm.config import get_current_vllm_config - config = get_current_vllm_config_or_none() - if config is not None: - data_parallel_size = config.parallel_config.data_parallel_size + config = get_current_vllm_config() + data_parallel_size = config.parallel_config.data_parallel_size + enable_elastic_ep = config.parallel_config.enable_elastic_ep + if enable_elastic_ep: + # Use stateless world group for global information + world_size = get_world_group().world_size + rank = get_world_group().rank + backend = backend or "nccl" + tp_pp_pcp_size = ( + tensor_model_parallel_size + * pipeline_model_parallel_size + * prefill_context_model_parallel_size + ) + local_all_ranks = torch.arange(tp_pp_pcp_size).reshape( + pipeline_model_parallel_size, + prefill_context_model_parallel_size, + tensor_model_parallel_size, + ) + else: + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group + ) # the layout order is: ExternalDP x DP x PP x TP # ExternalDP is the data parallel group that is not part of the model, @@ -1437,7 +1556,9 @@ def initialize_model_parallel( assert _TP is None, "tensor model parallel group is already initialized" group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] - + if enable_elastic_ep: + group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group( group_ranks, @@ -1456,6 +1577,11 @@ def initialize_model_parallel( # TP group into tp_size//dcp_size DCP groups. 
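For reference, a minimal sketch (not vLLM code) of how the elastic-EP branch above derives TP group rank lists from local_all_ranks; the DCP/PCP/PP hunks below apply the same pattern with different reshapes and transposes, while the DP/EP/EPLB groups further down use global ranks through StatelessGroupCoordinator:

import torch

PP, PCP, TP = 2, 1, 2  # hypothetical sizes for one DP replica
local_all_ranks = torch.arange(PP * PCP * TP).reshape(PP, PCP, TP)
tp_groups = [x.tolist() for x in local_all_ranks.view(-1, TP).unbind(0)]
assert tp_groups == [[0, 1], [2, 3]]  # two TP groups, local to this DP replica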
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] + if enable_elastic_ep: + group_ranks = local_all_ranks.reshape( + -1, decode_context_model_parallel_size + ).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] _DCP = init_model_parallel_group( group_ranks, get_world_group().local_rank, @@ -1472,6 +1598,13 @@ def initialize_model_parallel( .unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] + if enable_elastic_ep: + group_ranks = ( + local_all_ranks.transpose(1, 2) + .reshape(-1, prefill_context_model_parallel_size) + .unbind(0) + ) + group_ranks = [x.tolist() for x in group_ranks] _PCP = init_model_parallel_group( group_ranks, get_world_group().local_rank, backend, group_name="pcp" ) @@ -1483,6 +1616,13 @@ def initialize_model_parallel( all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] + if enable_elastic_ep: + group_ranks = ( + local_all_ranks.transpose(0, 2) + .reshape(-1, pipeline_model_parallel_size) + .unbind(0) + ) + group_ranks = [x.tolist() for x in group_ranks] _PP = init_model_parallel_group( group_ranks, get_world_group().local_rank, backend, group_name="pp" ) @@ -1491,14 +1631,27 @@ def initialize_model_parallel( assert _DP is None, "data parallel group is already initialized" group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] - _DP = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="dp" - ) + if enable_elastic_ep: + parallel_config = config.parallel_config + dp_ports = [ + parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks + ] + _DP = _init_stateless_group( + group_ranks, + "dp", + dp_ports, + parallel_config.data_parallel_master_ip, + backend, + ) + else: + _DP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="dp" + ) global _EP assert _EP is None, "expert parallel group is already initialized" # Don't create EP group for dense models. - if config is None or config.model_config is None or config.model_config.is_moe: + if config.model_config is None or config.model_config.is_moe: group_ranks = ( all_ranks.transpose(1, 2) .reshape( @@ -1510,9 +1663,22 @@ def initialize_model_parallel( .unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] - _EP = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="ep" - ) + if enable_elastic_ep: + parallel_config = config.parallel_config + ep_ports = [ + parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks + ] + _EP = _init_stateless_group( + group_ranks, + "ep", + ep_ports, + parallel_config.data_parallel_master_ip, + backend, + ) + else: + _EP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="ep" + ) # Create EPLB group with the same ranks as EP if EPLB is enabled. 
# This is a separate process group to isolate EPLB communications @@ -1525,10 +1691,25 @@ def initialize_model_parallel( and config.parallel_config is not None and config.parallel_config.enable_eplb ): - # Reuse the same group_ranks from EP - _EPLB = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="eplb" - ) + if enable_elastic_ep: + eplb_ports = [ + parallel_config.get_next_stateless_eplb_group_port() + for _ in group_ranks + ] + _EPLB = _init_stateless_group( + group_ranks, + "eplb", + eplb_ports, + parallel_config.data_parallel_master_ip, + backend, + ) + else: + _EPLB = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name="eplb", + ) # If no EP group needed, _EP remains None # If no EPLB group needed, _EPLB remains None @@ -1558,7 +1739,11 @@ def ensure_model_parallel_initialized( or ensure tensor-parallel and pipeline-parallel sizes are equal to expected values if the model parallel groups are initialized. """ - backend = backend or torch.distributed.get_backend(get_world_group().device_group) + world_group = get_world_group() + if hasattr(world_group, "backend"): + backend = backend or world_group.backend + else: + backend = backend or torch.distributed.get_backend(world_group.device_group) if not model_parallel_is_initialized(): initialize_model_parallel( tensor_model_parallel_size, diff --git a/vllm/distributed/stateless_coordinator.py b/vllm/distributed/stateless_coordinator.py new file mode 100644 index 000000000..f2126fdba --- /dev/null +++ b/vllm/distributed/stateless_coordinator.py @@ -0,0 +1,322 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import torch +from torch.distributed import Backend, ProcessGroup + +from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator +from vllm.distributed.parallel_state import ( + GroupCoordinator, + TensorMetadata, + _get_unique_name, + _register_group, + _split_tensor_dict, +) +from vllm.distributed.utils import ( + StatelessProcessGroup, + stateless_destroy_torch_distributed_process_group, + stateless_init_torch_distributed_process_group, +) +from vllm.logger import init_logger +from vllm.utils.import_utils import resolve_obj_by_qualname + +logger = init_logger(__name__) + + +class StatelessGroupCoordinator(GroupCoordinator): + """ + A stateless version of the GroupCoordinator class in parallel_state, + It will create CPU, device and TCPStore based communication groups + that are independent of PyTorch's WORLD group. Hence, + communication groups with a different set of participants GPUs + can be created without destroying the existing ones. 
+ """ + + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: str | Backend, + use_device_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: str | None = None, + host: str = "127.0.0.1", + group_ports: list[list[int]] | None = None, + global_rank: int = 0, + global_world_size: int = 1, + ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) + + self.rank = global_rank + self.local_rank = local_rank + + self_device_group = None + self_cpu_group = None + self_tcp_store_group = None + + from vllm.platforms import current_platform + + backend = str(torch_distributed_backend) + self.backend = backend + assert group_ports is not None, "group_ports is not provided" + for idx, ranks in enumerate(group_ranks): + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + + ports = group_ports[idx] + device_port = ports[0] + cpu_port = ports[1] + tcp_store_port = ports[2] + + device_group = stateless_init_torch_distributed_process_group( + host=host, + port=device_port, + rank=self.rank_in_group, + world_size=self.world_size, + backend=backend, + group_name=f"{self.unique_name}_device", + ) + cpu_group = stateless_init_torch_distributed_process_group( + host=host, + port=cpu_port, + rank=self.rank_in_group, + world_size=self.world_size, + backend="gloo", + group_name=f"{self.unique_name}_cpu", + ) + tcp_store_group = StatelessProcessGroup.create( + host=host, + port=tcp_store_port, + rank=self.rank_in_group, + world_size=self.world_size, + ) + + self_device_group = device_group + self_cpu_group = cpu_group + self_tcp_store_group = tcp_store_group + + assert self_cpu_group is not None + assert self_device_group is not None + assert self_tcp_store_group is not None + + self.cpu_group = self_cpu_group + self.device_group = self_device_group + self.tcp_store_group = self_tcp_store_group + + if current_platform.is_cuda_alike(): + self.device = torch.device(f"cuda:{local_rank}") + elif current_platform.is_xpu(): + self.device = torch.device(f"xpu:{local_rank}") + elif current_platform.is_out_of_tree(): + self.device = torch.device(f"{current_platform.device_name}:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_device_communicator = use_device_communicator + self.device_communicator = None + if use_device_communicator and self.world_size > 1: + device_comm_cls = resolve_obj_by_qualname( + current_platform.get_device_communicator_cls() + ) + assert device_comm_cls == CudaCommunicator + self.device_communicator = CudaCommunicator( + cpu_group=self.cpu_group, + device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, + global_ranks=self.ranks, + global_world_size=global_world_size, + tcp_store_group=self.tcp_store_group, + ) + + self.mq_broadcaster = None + + self.use_custom_op_call = ( + current_platform.is_cuda_alike() or current_platform.is_tpu() + ) + self.use_cpu_custom_send_recv = False + + def destroy(self): + if self.device_communicator: + self.device_communicator.destroy() + if self.device_group: + stateless_destroy_torch_distributed_process_group(self.device_group) + if self.cpu_group: + stateless_destroy_torch_distributed_process_group(self.cpu_group) + + def size(self) -> int: + """Return the world size of this group.""" + return self.world_size + + def broadcast(self, input_: torch.Tensor, src: int = 0): + if self.world_size == 1: + return 
input_ + + if self.device_communicator and input_.is_cuda: + return self.device_communicator.broadcast(input_, src) + else: + return self.tcp_store_group.broadcast(input_, src) + + def broadcast_object(self, obj=None, src: int = 0): + if self.world_size == 1: + return obj + return self.tcp_store_group.broadcast_obj(obj, src) + + def broadcast_object_list( + self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None + ): + assert src < self.world_size + + if self.world_size == 1: + return obj_list + + if self.rank_in_group == src: + for obj in obj_list: + self.tcp_store_group.broadcast_obj(obj, src) + else: + for i in range(len(obj_list)): + obj_list[i] = self.tcp_store_group.broadcast_obj(None, src) + + return obj_list + + def broadcast_tensor_dict( + self, + tensor_dict: dict[str, torch.Tensor | Any] | None = None, + src: int = 0, + group: ProcessGroup | None = None, + metadata_group: ProcessGroup | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + if self.world_size == 1: + return tensor_dict + + if self.rank_in_group == src: + assert isinstance(tensor_dict, dict), ( + f"Expecting a dictionary, got {type(tensor_dict)}" + ) + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + else: + metadata_list = None + tensor_list = [] + + recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj( + metadata_list, src + ) + + if self.rank_in_group != src: + tensor_dict = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + tensor_list.append(tensor) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + + for tensor in tensor_list: + if tensor.numel() == 0: + continue + if self.device_communicator and tensor.is_cuda: + tensor.copy_(self.device_communicator.broadcast(tensor, src)) + else: + tensor.copy_(self.tcp_store_group.broadcast(tensor, src)) + + return tensor_dict + + def send_object(self, obj, dst: int) -> None: + assert dst < self.world_size + assert dst != self.rank_in_group + self.tcp_store_group.send_obj(obj, dst) + + def recv_object(self, src: int): + assert src < self.world_size + assert src != self.rank_in_group + return self.tcp_store_group.recv_obj(src) + + def send_tensor_dict( + self, + tensor_dict: dict[str, torch.Tensor | Any], + dst: int | None = None, + all_gather_group: Optional["GroupCoordinator"] = None, + all_gather_tensors: dict[str, bool] | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + if self.world_size == 1: + return tensor_dict + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size + + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + self.tcp_store_group.send_obj(metadata_list, dst) + + for tensor in tensor_list: + if tensor.numel() == 0: + continue + if self.device_communicator and tensor.is_cuda: + self.device_communicator.send(tensor, dst) + else: + self.tcp_store_group.send(tensor, dst) + + return None + + def recv_tensor_dict( + self, + src: int | None = None, + all_gather_group: Optional["GroupCoordinator"] = None, + all_gather_tensors: dict[str, bool] | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + if self.world_size == 1: + return None + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size + + recv_metadata_list = self.tcp_store_group.recv_obj(src) + tensor_dict = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + 
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() > 0: + if self.device_communicator and tensor.is_cuda: + tensor = self.device_communicator.recv( + tensor.size(), tensor.dtype, src + ) + else: + tensor = self.tcp_store_group.recv(tensor, src) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + + def barrier(self): + self.tcp_store_group.barrier() + + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> torch.Tensor | None: + if self.world_size == 1: + return input_ + + if self.device_communicator is None: + raise ValueError("No device communicator found") + + if self.rank_in_group == dst: + gathered_list = [torch.empty_like(input_) for _ in range(self.world_size)] + gathered_list[self.rank_in_group] = input_ + for src_rank in range(self.world_size): + if src_rank != self.rank_in_group: + gathered_list[src_rank] = self.device_communicator.recv( + input_.size(), input_.dtype, src_rank + ) + return torch.cat(gathered_list, dim=dim) + else: + self.device_communicator.send(input_, dst) + return None diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 17375259e..102f2f727 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -18,7 +18,7 @@ from datetime import timedelta from typing import Any import torch -from torch.distributed import ProcessGroup, TCPStore +from torch.distributed import ProcessGroup, Store, TCPStore from torch.distributed.distributed_c10d import ( Backend, PrefixStore, @@ -228,6 +228,55 @@ class StatelessProcessGroup: gathered_objs.append(recv_obj) return gathered_objs + def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: + """Broadcast a tensor from source rank to all other ranks.""" + if self.rank == src: + tensor_bytes = pickle.dumps(tensor) + self.expire_data() + key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}" + self.store.set(key, tensor_bytes) + self.broadcast_send_counter += 1 + self.entries.append((key, time.time())) + return tensor + else: + key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}" + tensor = pickle.loads(self.store.get(key)) + self.broadcast_recv_src_counter[src] += 1 + return tensor + + def send(self, tensor: torch.Tensor, dst: int): + """Send a tensor to a destination rank.""" + self.expire_data() + key = f"send_tensor/{dst}/{self.send_dst_counter[dst]}" + self.store.set(key, pickle.dumps(tensor)) + self.send_dst_counter[dst] += 1 + self.entries.append((key, time.time())) + + def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor: + """Receive a tensor from a source rank.""" + key = f"send_tensor/{self.rank}/{self.recv_src_counter[src]}" + received = pickle.loads(self.store.get(key)) + self.recv_src_counter[src] += 1 + tensor.copy_(received) + return tensor + + def all_reduce( + self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM + ) -> torch.Tensor: + """All-reduce a tensor across all ranks.""" + tensors = self.all_gather_obj(tensor) + result = tensors[0].clone() + for t in tensors[1:]: + if op == torch.distributed.ReduceOp.SUM: + result.add_(t) + elif op == torch.distributed.ReduceOp.PRODUCT: + result.mul_(t) + elif op == torch.distributed.ReduceOp.MAX: + result = torch.maximum(result, t) + elif op == torch.distributed.ReduceOp.MIN: + result = torch.minimum(result, t) + return result + def barrier(self, timeout: float = 30.0): """A robust barrier to synchronize all ranks. 
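The additions above give StatelessProcessGroup tensor-level broadcast, send/recv, and all_reduce built on its shared TCPStore, so small control-plane tensors can be exchanged without any torch.distributed group. A rough usage sketch (hypothetical host and port; each participating process passes its own rank):

import torch
from vllm.distributed.utils import StatelessProcessGroup

rank = 0  # set per process
group = StatelessProcessGroup.create(
    host="127.0.0.1", port=29600, rank=rank, world_size=2
)
total = group.all_reduce(torch.ones(4))           # elementwise sum across ranks
shared = group.broadcast(torch.arange(4), src=0)  # everyone receives rank 0's tensor
group.barrier()

Because tensors travel through the store as pickled blobs, this path suits metadata and small tensors rather than bulk expert weights.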
@@ -448,8 +497,14 @@ def init_gloo_process_group( def stateless_init_torch_distributed_process_group( - host: str, port: int, rank: int, world_size: int, backend: str -) -> ProcessGroup: + host: str, + port: int, + rank: int, + world_size: int, + backend: str, + group_name: str | None = None, + return_store: bool = False, +) -> ProcessGroup | tuple[ProcessGroup, Store]: """ A replacement for `torch.distributed.init_process_group` that does not pollute the global state. The created ProcessGroup object can be used for @@ -496,25 +551,35 @@ def stateless_init_torch_distributed_process_group( # Use a PrefixStore to avoid accidental overrides of keys used by # different systems (e.g. RPC) in case the store is multi-tenant. prefix_store = PrefixStore(init_method, store) - try: + + if backend == "gloo": + pg = init_gloo_process_group( + prefix_store=prefix_store, + group_rank=group_rank, + group_size=group_size, + timeout=timeout, + ) + else: from vllm.platforms import current_platform - return current_platform.stateless_init_device_torch_dist_pg( + pg = current_platform.stateless_init_device_torch_dist_pg( backend=backend, prefix_store=prefix_store, group_rank=group_rank, group_size=group_size, timeout=timeout, ) - except NotImplementedError: - # If platform doesn't implement stateless_init_device_torch_dist_pg, it - # will raise a NotImplementedError. In this case, we fall back to gloo. - return init_gloo_process_group( - prefix_store=prefix_store, - group_rank=group_rank, - group_size=group_size, - timeout=timeout, - ) + + if group_name is not None: + from torch._C._distributed_c10d import _register_process_group + + pg._set_group_name(group_name) + _register_process_group(group_name, pg) + + if return_store: + return pg, store + else: + return pg def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2e9cd6634..64b505a1d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -419,6 +419,7 @@ class EngineArgs: enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel moe_backend: MoEBackend = KernelConfig.moe_backend all2all_backend: All2AllBackend = ParallelConfig.all2all_backend + enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep enable_dbo: bool = ParallelConfig.enable_dbo ubatch_size: int = ParallelConfig.ubatch_size dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold @@ -896,6 +897,9 @@ class EngineArgs: "--ubatch-size", **parallel_kwargs["ubatch_size"], ) + parallel_group.add_argument( + "--enable-elastic-ep", **parallel_kwargs["enable_elastic_ep"] + ) parallel_group.add_argument( "--dbo-decode-token-threshold", **parallel_kwargs["dbo_decode_token_threshold"], @@ -1698,6 +1702,7 @@ class EngineArgs: is_moe_model=model_config.is_moe, enable_expert_parallel=self.enable_expert_parallel, all2all_backend=self.all2all_backend, + enable_elastic_ep=self.enable_elastic_ep, enable_dbo=self.enable_dbo, ubatch_size=self.ubatch_size, dbo_decode_token_threshold=self.dbo_decode_token_threshold, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index c12cc7ff2..9e3988b15 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -246,8 +246,12 @@ def run_multi_api_server(args: argparse.Namespace): api_server_manager: APIServerProcessManager | None = None + from vllm.v1.engine.utils import get_engine_zmq_addresses + + addresses = get_engine_zmq_addresses(vllm_config, num_api_servers) + with 
launch_core_engines( - vllm_config, executor_class, log_stats, num_api_servers + vllm_config, executor_class, log_stats, addresses, num_api_servers ) as (local_engine_manager, coordinator, addresses): # Construct common args for the APIServerProcessManager up-front. api_server_manager_kwargs = dict( diff --git a/vllm/envs.py b/vllm/envs.py index 07d9f81ea..864ea6649 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -243,6 +243,8 @@ if TYPE_CHECKING: VLLM_LORA_DISABLE_PDL: bool = False VLLM_ENABLE_CUDA_COMPATIBILITY: bool = False VLLM_CUDA_COMPATIBILITY_PATH: str | None = None + VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False + VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False def get_default_cache_root(): @@ -1617,6 +1619,16 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_CUDA_COMPATIBILITY_PATH": lambda: os.environ.get( "VLLM_CUDA_COMPATIBILITY_PATH", None ), + # Whether it is a scale up launch engine for elastic EP, + # Should only be set by EngineCoreClient. + "VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": lambda: bool( + int(os.getenv("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH", "0")) + ), + # Whether to wait for all requests to drain before sending the + # scaling command in elastic EP. + "VLLM_ELASTIC_EP_DRAIN_REQUESTS": lambda: bool( + int(os.getenv("VLLM_ELASTIC_EP_DRAIN_REQUESTS", "0")) + ), } diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a7dee7004..620047709 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -627,6 +627,7 @@ class FusedMoE(CustomOp): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) + self.base_quant_method = self.quant_method # Disable shared expert overlap if: # - we are using eplb with non-default backend, because of correctness issues @@ -683,7 +684,7 @@ class FusedMoE(CustomOp): # routing_tables only needed for round-robin expert placement with # DeepEP all2all backend. routing_tables = self._maybe_init_expert_routing_tables() - prepare_finalize = self.quant_method.maybe_make_prepare_finalize( + prepare_finalize = self.base_quant_method.maybe_make_prepare_finalize( routing_tables=routing_tables ) if prepare_finalize is not None: @@ -693,7 +694,7 @@ class FusedMoE(CustomOp): self._replace_quant_method( FusedMoEModularMethod.make( self, - self.quant_method, + self.base_quant_method, prepare_finalize, self.shared_experts, inplace=not self.moe_config.disable_inplace, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index d3312fe15..af627964f 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -6,10 +6,13 @@ pynvml. However, it should not initialize cuda context. 
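The cuda.py and rocm.py hunks that follow implement stateless_init_device_torch_dist_pg, which stateless_init_torch_distributed_process_group (updated in vllm/distributed/utils.py above) now calls directly for non-gloo backends. A rough calling-side sketch (hypothetical host and ports; each process passes its own rank):

from vllm.distributed.utils import stateless_init_torch_distributed_process_group

rank = 0  # set per process
device_pg = stateless_init_torch_distributed_process_group(
    host="10.0.0.1", port=29700, rank=rank, world_size=4,
    backend="nccl", group_name="eep_device",
)
cpu_pg = stateless_init_torch_distributed_process_group(
    host="10.0.0.1", port=29701, rank=rank, world_size=4, backend="gloo",
)

Nothing here touches the default WORLD group, so groups with different memberships can be created and destroyed independently, which is what elastic EP relies on when the DP size changes.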
import os from collections.abc import Callable +from datetime import timedelta from functools import cache, wraps from typing import TYPE_CHECKING, TypeVar import torch +from torch.distributed import PrefixStore, ProcessGroup +from torch.distributed.distributed_c10d import is_nccl_available from typing_extensions import ParamSpec # import custom ops, trigger op registration @@ -482,6 +485,37 @@ class CudaPlatformBase(Platform): def get_static_graph_wrapper_cls(cls) -> str: return "vllm.compilation.cuda_graph.CUDAGraphWrapper" + @classmethod + def stateless_init_device_torch_dist_pg( + cls, + backend: str, + prefix_store: PrefixStore, + group_rank: int, + group_size: int, + timeout: timedelta, + ) -> ProcessGroup: + assert is_nccl_available() + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + from torch.distributed.distributed_c10d import ProcessGroupNCCL + + backend_options = ProcessGroupNCCL.Options() + backend_options._timeout = timeout + + backend_class = ProcessGroupNCCL( + prefix_store, group_rank, group_size, backend_options + ) + backend_type = ProcessGroup.BackendType.NCCL + device = torch.device("cuda") + pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(device, backend_type, backend_class) + return pg + @classmethod def device_count(cls) -> int: return cuda_device_count_stateless() diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 3808ecc6e..e867ebbd6 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -2,10 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +from datetime import timedelta from functools import cache, lru_cache, wraps from typing import TYPE_CHECKING import torch +from torch.distributed import PrefixStore, ProcessGroup +from torch.distributed.distributed_c10d import is_nccl_available import vllm.envs as envs from vllm.logger import init_logger @@ -656,6 +659,37 @@ class RocmPlatform(Platform): def get_static_graph_wrapper_cls(cls) -> str: return "vllm.compilation.cuda_graph.CUDAGraphWrapper" + @classmethod + def stateless_init_device_torch_dist_pg( + cls, + backend: str, + prefix_store: PrefixStore, + group_rank: int, + group_size: int, + timeout: timedelta, + ) -> ProcessGroup: + assert is_nccl_available() + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + from torch.distributed.distributed_c10d import ProcessGroupNCCL + + backend_options = ProcessGroupNCCL.Options() + backend_options._timeout = timeout + + backend_class = ProcessGroupNCCL( + prefix_store, group_rank, group_size, backend_options + ) + backend_type = ProcessGroup.BackendType.NCCL + device = torch.device("cuda") + pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(device, backend_type, backend_class) + return pg + @classmethod def device_count(cls) -> int: return cuda_device_count_stateless() diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 1dd9f64f8..19413ddb4 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -29,6 +29,15 @@ PauseMode = Literal["abort", "wait", "keep"] # so form part of the external API. 
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error") +EEP_NOTIFICATION_CALL_ID = -1 + + +class EEPNotificationType(enum.Enum): + NEW_CORE_ENGINES_INIT_READY = "NEW_CORE_ENGINES_INIT_READY" + NEW_CORE_ENGINES_WEIGHTS_INIT_READY = "NEW_CORE_ENGINES_WEIGHTS_INIT_READY" + RECONFIGURE_FINISHED = "RECONFIGURE_FINISHED" + SHUTDOWN_COMPLETE = "SHUTDOWN_COMPLETE" + class FinishReason(enum.IntEnum): """ @@ -235,6 +244,11 @@ class ReconfigureDistributedRequest(msgspec.Struct): new_data_parallel_rank_local: int new_data_parallel_master_ip: str new_data_parallel_master_port: int + new_data_parallel_master_port_list: list[int] + new_stateless_world_group_port_list: list[list[int]] + new_stateless_dp_group_port_list: list[list[int]] + new_stateless_ep_group_port_list: list[list[int]] + new_stateless_eplb_group_port_list: list[list[int]] class ReconfigureRankType(enum.IntEnum): diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index d86e1b43d..f172d6dda 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -20,6 +20,7 @@ from vllm.distributed.weight_transfer.base import ( ) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient, StreamingInput +from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -647,7 +648,11 @@ class AsyncLLM(EngineClient): engine_core = self.engine_core output_processor = self.output_processor log_stats = self.log_stats - logger_manager = self.logger_manager + # We use a mutable list for logger_manager so that it can be updated + # during elastic EP scaling (see scale_elastic_ep) without creating + # a circular reference via self. + self._logger_ref = [self.logger_manager] + logger_ref = self._logger_ref renderer = self.renderer chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE @@ -691,8 +696,8 @@ class AsyncLLM(EngineClient): # 4) Logging. # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. - if logger_manager: - logger_manager.record( + if logger_ref[0]: + logger_ref[0].record( engine_idx=outputs.engine_index, scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, @@ -976,17 +981,13 @@ class AsyncLLM(EngineClient): new_data_parallel_size, ) return - logger.info( - "Waiting for requests to drain before scaling up to %s engines...", - new_data_parallel_size, - ) - await self.wait_for_requests_to_drain(drain_timeout) - logger.info( - "Requests have been drained, proceeding with scale to %s engines", - new_data_parallel_size, - ) - await self.engine_core.scale_elastic_ep(new_data_parallel_size) - self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size + + if envs.VLLM_ELASTIC_EP_DRAIN_REQUESTS: + logger.info( + "VLLM_ELASTIC_EP_DRAIN_REQUESTS is set, " + "waiting for requests to drain before scaling" + ) + await self.wait_for_requests_to_drain(drain_timeout) # recreate stat loggers if new_data_parallel_size > old_data_parallel_size and self.log_stats: @@ -999,6 +1000,18 @@ class AsyncLLM(EngineClient): engine_idxs=list(range(new_data_parallel_size)), custom_stat_loggers=None, ) + # Update the mutable ref so output_handler picks up the + # new logger without creating a circular reference via self. 
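# (Illustration of the mutable-reference pattern used here, not vLLM code:
#  the output handler closes over a one-element list rather than over self,
#  so the logger can be swapped during scaling without restarting the task
#  and without the closure keeping the AsyncLLM instance alive.
#
#      logger_ref = [old_logger]
#      def handler():
#          logger_ref[0].record(...)   # always reads the current element
#      logger_ref[0] = new_logger      # handler now records via new_logger)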
+ if hasattr(self, "_logger_ref"): + self._logger_ref[0] = self.logger_manager + self.logger_manager.log_engine_initialized() + + set_scaling_elastic_ep(True) + try: + await self.engine_core.scale_elastic_ep(new_data_parallel_size) + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size + finally: + set_scaling_elastic_ep(False) @property def is_running(self) -> bool: diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 672d536a5..44a346350 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -71,6 +71,9 @@ class DPCoordinator: ) local_only_eng = dp_size == parallel_config.data_parallel_size_local + # NOTE(yongji): handling scaling from intra-node to inter-node + if parallel_config.enable_elastic_ep: + local_only_eng = False back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) back_output_address = get_engine_client_zmq_addr(local_only_eng, host) @@ -201,6 +204,7 @@ class DPCoordinatorProc: poller = zmq.Poller() poller.register(publish_front, zmq.POLLIN) + poller.register(publish_back, zmq.POLLIN) poller.register(output_back, zmq.POLLIN) last_publish_time = 0 while True: @@ -231,6 +235,22 @@ class DPCoordinatorProc: events = dict(events) wave_state_changed = False + if publish_back in events: + buffer = publish_back.recv() + if buffer == b"\x01": + # NOTE(yongji): newly started engine subscribed + # We need to send READY message here instead of receiving + # SCALE_ELASTIC_EP notification from engine core client + # as SCALE_ELASTIC_EP is only sent when + # new engines finished initialization. + # Subscription message, on the other hand, is sent + # by each engine during initialization + publish_back.send(b"READY") + else: + logger.error( + "DP Coordinator receives unexpected message from engines" + ) + if publish_front in events: buffer = publish_front.recv() if buffer in (b"\x01", b"\x00"): @@ -259,7 +279,6 @@ class DPCoordinatorProc: # current_wave # we note that 0 is the wave number for the new # engine - engines_running = False logger.info( "DPCoordinator scaled up from %s to %s engines", current_count, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 39515cab7..4de3e4ea7 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -17,6 +17,7 @@ from typing import Any, TypeVar, cast import msgspec import zmq +import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.envs import enable_envs_cache @@ -44,6 +45,8 @@ from vllm.v1.core.kv_cache_utils import ( from vllm.v1.core.sched.interface import PauseState, SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import ( + EEP_NOTIFICATION_CALL_ID, + EEPNotificationType, EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, @@ -110,6 +113,9 @@ class EngineCore: self.available_gpu_memory_for_kv_cache = -1 + if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: + self._eep_scale_up_before_kv_init() + # Setup KV Caches and update CacheConfig after profiling. 
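# (When VLLM_ELASTIC_EP_SCALE_UP_LAUNCH is set, _eep_scale_up_before_kv_init
#  above has already populated available_gpu_memory_for_kv_cache from the
#  existing engines via ParallelConfig.sync_kv_cache_memory_size, so a newly
#  added engine sizes its KV cache to match its peers instead of re-profiling.)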
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( vllm_config @@ -233,12 +239,10 @@ class EngineCore: has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs) if has_kv_cache: - if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1": - dp_group = getattr(self, "dp_group", None) - assert dp_group is not None - self.available_gpu_memory_for_kv_cache = ( - ParallelConfig.sync_kv_cache_memory_size(dp_group, -1) - ) + if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: + # NOTE(yongji): should already be set + # during _eep_scale_up_before_kv_init + assert self.available_gpu_memory_for_kv_cache > 0 available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len( kv_cache_specs ) @@ -752,11 +756,22 @@ class EngineCore: self.structured_output_manager.grammar_init(req) return req, request.current_wave + def _eep_scale_up_before_kv_init(self): + raise NotImplementedError + + def _eep_send_engine_core_notification( + self, + notification_type: EEPNotificationType, + vllm_config: VllmConfig | None = None, + ): + raise NotImplementedError + class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD" + addresses: EngineZmqAddresses @instrument(span_name="EngineCoreProc init") def __init__( @@ -807,6 +822,13 @@ class EngineCoreProc(EngineCore): # and "hybrid" LB modes. self.publish_dp_lb_stats = internal_dp_balancing + self.addresses = addresses + self.process_input_queue_block = True + if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: + self._eep_send_engine_core_notification( + EEPNotificationType.NEW_CORE_ENGINES_INIT_READY, + vllm_config=vllm_config, + ) self._init_data_parallel(vllm_config) super().__init__( @@ -1119,8 +1141,14 @@ class EngineCoreProc(EngineCore): if logger.isEnabledFor(DEBUG): logger.debug("EngineCore waiting for work.") waited = True - req = self.input_queue.get() - self._handle_client_request(*req) + block = self.process_input_queue_block + try: + req = self.input_queue.get(block=block) + self._handle_client_request(*req) + except queue.Empty: + break + if not block: + break if waited: logger.debug("EngineCore loop active.") @@ -1290,6 +1318,11 @@ class EngineCoreProc(EngineCore): for input_socket, _ in poller.poll(): # (RequestType, RequestData) type_frame, *data_frames = input_socket.recv_multipart(copy=False) + # NOTE(yongji): ignore READY message sent by DP coordinator + # that is used to notify newly started engines + if type_frame.buffer == b"READY": + assert input_socket == coord_socket + continue request_type = EngineCoreRequestType(bytes(type_frame.buffer)) # Deserialize the request data. @@ -1488,6 +1521,10 @@ class DPEngineCoreProc(EngineCoreProc): self.current_wave = 0 self.last_counts = (0, 0) + from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState + + self.eep_scaling_state: ElasticEPScalingState | None = None + # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank super().__init__( @@ -1511,7 +1548,9 @@ class DPEngineCoreProc(EngineCoreProc): assert 0 <= local_dp_rank <= dp_rank < dp_size self.dp_rank = dp_rank - self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() + self.dp_group, self.dp_store = ( + vllm_config.parallel_config.stateless_init_dp_group(return_store=True) + ) def shutdown(self): super().shutdown() @@ -1574,7 +1613,12 @@ class DPEngineCoreProc(EngineCoreProc): # 1) Poll the input queue until there is work to do. self._process_input_queue() - # 2) Step the engine core. 
+ if self.eep_scaling_state is not None: + _ = self.eep_scaling_state.progress() + if self.eep_scaling_state.is_complete(): + self.process_input_queue_block = True + self.eep_scaling_state = None + executed = self._process_engine_step() self._maybe_publish_request_counts() @@ -1624,54 +1668,129 @@ class DPEngineCoreProc(EngineCoreProc): def reinitialize_distributed( self, reconfig_request: ReconfigureDistributedRequest ) -> None: - stateless_destroy_torch_distributed_process_group(self.dp_group) - self.shutdown() + from copy import deepcopy - parallel_config = self.vllm_config.parallel_config - old_dp_size = parallel_config.data_parallel_size - parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size - if reconfig_request.new_data_parallel_rank != -1: - parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank - # local rank specifies device visibility, it should not be changed - assert ( - reconfig_request.new_data_parallel_rank_local - == ReconfigureRankType.KEEP_CURRENT_RANK - ) - parallel_config.data_parallel_master_ip = ( - reconfig_request.new_data_parallel_master_ip - ) - parallel_config.data_parallel_master_port = ( - reconfig_request.new_data_parallel_master_port - ) - if reconfig_request.new_data_parallel_rank != -2: - self.dp_rank = parallel_config.data_parallel_rank - self.dp_group = parallel_config.stateless_init_dp_group() - reconfig_request.new_data_parallel_master_port = ( - parallel_config.data_parallel_master_port - ) + from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState - self.model_executor.reinitialize_distributed(reconfig_request) - if reconfig_request.new_data_parallel_size > old_dp_size: - assert self.available_gpu_memory_for_kv_cache > 0 - # pass available_gpu_memory_for_kv_cache from existing - # engine-cores to new engine-cores so they can directly - # use it in _initialize_kv_caches() rather than profiling. 
- ParallelConfig.sync_kv_cache_memory_size( - self.dp_group, self.available_gpu_memory_for_kv_cache - ) - # NOTE(yongji): newly joined workers require dummy_run even - # CUDA graph is not used - self.model_executor.collective_rpc("compile_or_warm_up_model") + new_parallel_config = deepcopy(self.vllm_config.parallel_config) + old_dp_size = new_parallel_config.data_parallel_size + new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size if ( reconfig_request.new_data_parallel_rank - == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + != ReconfigureRankType.KEEP_CURRENT_RANK ): - self.shutdown() - logger.info("DPEngineCoreProc %s shutdown", self.dp_rank) - else: - logger.info( - "Distributed environment reinitialized for DP rank %s", self.dp_rank + new_parallel_config.data_parallel_rank = ( + reconfig_request.new_data_parallel_rank ) + new_parallel_config.data_parallel_master_ip = ( + reconfig_request.new_data_parallel_master_ip + ) + new_parallel_config.data_parallel_master_port = ( + reconfig_request.new_data_parallel_master_port + ) + new_parallel_config._data_parallel_master_port_list = ( + reconfig_request.new_data_parallel_master_port_list + ) + + is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size + is_shutdown = ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ) + + self.eep_scaling_state = ElasticEPScalingState( + model_executor=self.model_executor, + engine_core=self, + vllm_config=self.vllm_config, + new_parallel_config=new_parallel_config, + worker_type="removing" if is_shutdown else "existing", + scale_type="scale_down" if is_scale_down else "scale_up", + reconfig_request=reconfig_request, + ) + self.process_input_queue_block = False + logger.info( + "[Elastic EP] Received reconfiguration request and starting scaling up/down" + ) + + def _eep_send_engine_core_notification( + self, + notification_type: EEPNotificationType, + vllm_config: VllmConfig | None = None, + ): + """ + Send notifications to EngineCoreClient, which can then forward + the notifications to other engine core processes. It is used for: + 1) In scale up: new core engines to notify exisiting core engines + that they are ready; + 2) In scale down: removing core engines to notify EngineCoreClient + so EngineCoreClient can release their ray placement groups; + 3) Both scale up/down: to notify EngineCoreClient that exisiting + core engines have already switched to the new parallel setup. + """ + if vllm_config is None: + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + else: + dp_rank = vllm_config.parallel_config.data_parallel_rank + notification_data = (notification_type.value, dp_rank) + outputs = EngineCoreOutputs( + utility_output=UtilityOutput( + call_id=EEP_NOTIFICATION_CALL_ID, + result=UtilityResult(notification_data), + ) + ) + outputs.engine_index = self.engine_index + + if hasattr(self, "output_thread") and self.output_thread.is_alive(): + self.output_queue.put_nowait((0, outputs)) + else: + encoder = MsgpackEncoder() + with ( + zmq.Context() as ctx, + make_zmq_socket( + ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000 + ) as socket, + ): + socket.send_multipart(encoder.encode(outputs)) + + def eep_handle_engine_core_notification( + self, notification_type: str | EEPNotificationType + ): + """ + Handle notification received from EngineCoreClient + (forwarded from new core engines). 
+ """ + assert self.eep_scaling_state is not None + if isinstance(notification_type, str): + notification_type = EEPNotificationType(notification_type) + self.eep_scaling_state.handle_notification(notification_type) + + def _eep_scale_up_before_kv_init(self): + from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState + + self.eep_scaling_state = ElasticEPScalingState( + model_executor=self.model_executor, + engine_core=self, + vllm_config=self.vllm_config, + new_parallel_config=self.vllm_config.parallel_config, + worker_type="new", + scale_type="scale_up", + reconfig_request=None, + ) + self.model_executor.collective_rpc("init_device") + self.model_executor.collective_rpc("load_model") + self._eep_send_engine_core_notification( + EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY + ) + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("receive_weights",) + ) + self.available_gpu_memory_for_kv_cache = ( + ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1) + ) + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("prepare_new_worker",) + ) + self.process_input_queue_block = False class EngineCoreActorMixin: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 777dea5ae..e19b31396 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -28,11 +28,12 @@ from vllm.tracing import instrument from vllm.utils.async_utils import in_loop from vllm.utils.network_utils import ( close_sockets, - get_open_port, get_open_zmq_inproc_path, make_zmq_socket, ) from vllm.v1.engine import ( + EEP_NOTIFICATION_CALL_ID, + EEPNotificationType, EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, @@ -47,6 +48,7 @@ from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.utils import ( CoreEngineActorManager, CoreEngineProcManager, + get_engine_zmq_addresses, launch_core_engines, ) from vllm.v1.executor import Executor @@ -445,6 +447,63 @@ class BackgroundResources: raise EngineDeadError() +@dataclass +class ElasticScalingCache: + existing_core_engines: list[EngineIdentity] + num_new_core_engines: int + pending_notifications: dict[EEPNotificationType, set[int]] + + +def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int): + """ + Allocate stateless group ports for elastic EP. 
+ """ + from vllm.utils.network_utils import get_open_ports_list + + assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled" + world_size = parallel_config.world_size + new_world_size_across_dp = world_size * new_data_parallel_size + num_world_groups = 1 + num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size) + num_ep_groups = max( + 1, + new_world_size_across_dp + // (new_data_parallel_size * parallel_config.tensor_parallel_size), + ) + num_eplb_groups = num_ep_groups + total_ports_needed = ( + num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups + ) * 3 + 5 + all_ports = get_open_ports_list(total_ports_needed) + new_data_parallel_master_port_list = all_ports[-5:] + all_ports = all_ports[:-5] + new_stateless_world_group_port_list = [ + all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3) + ] + start_idx = num_world_groups * 3 + new_stateless_dp_group_port_list = [ + all_ports[i : i + 3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3) + ] + start_idx += num_dp_groups * 3 + new_stateless_ep_group_port_list = [ + all_ports[i : i + 3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3) + ] + start_idx += num_ep_groups * 3 + new_stateless_eplb_group_port_list = [ + all_ports[i : i + 3] + for i in range(start_idx, start_idx + num_eplb_groups * 3, 3) + ] + + parallel_config._stateless_world_group_port_list = ( + new_stateless_world_group_port_list + ) + parallel_config._stateless_dp_group_port_list = new_stateless_dp_group_port_list + parallel_config._stateless_ep_group_port_list = new_stateless_ep_group_port_list + parallel_config._stateless_eplb_group_port_list = new_stateless_eplb_group_port_list + parallel_config.data_parallel_master_port = new_data_parallel_master_port_list.pop() + parallel_config._data_parallel_master_port_list = new_data_parallel_master_port_list + + class MPClient(EngineCoreClient): """ MPClient: base client for multi-proc EngineCore. @@ -491,32 +550,37 @@ class MPClient(EngineCoreClient): input_address = client_addresses["input_address"] output_address = client_addresses["output_address"] self.stats_update_address = client_addresses.get("stats_update_address") + self.input_socket = self.resources.input_socket = make_zmq_socket( + self.ctx, input_address, zmq.ROUTER, bind=True + ) + self.resources.output_socket = make_zmq_socket( + self.ctx, output_address, zmq.PULL + ) else: # Engines are managed by this client. - with launch_core_engines(vllm_config, executor_class, log_stats) as ( - engine_manager, - coordinator, + addresses = get_engine_zmq_addresses(vllm_config) + self.input_socket = self.resources.input_socket = make_zmq_socket( + self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True + ) + self.resources.output_socket = make_zmq_socket( + self.ctx, addresses.outputs[0], zmq.PULL + ) + + with launch_core_engines( + vllm_config, + executor_class, + log_stats, addresses, - ): + ) as (engine_manager, coordinator, addresses): self.resources.coordinator = coordinator self.resources.engine_manager = engine_manager - (input_address,) = addresses.inputs - (output_address,) = addresses.outputs self.stats_update_address = addresses.frontend_stats_publish_address if coordinator is not None: assert self.stats_update_address == ( coordinator.get_stats_publish_address() ) - # Create input and output sockets. 
- self.input_socket = self.resources.input_socket = make_zmq_socket( - self.ctx, input_address, zmq.ROUTER, bind=True - ) - self.resources.output_socket = make_zmq_socket( - self.ctx, output_address, zmq.PULL - ) - parallel_config = vllm_config.parallel_config dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_index @@ -877,6 +941,10 @@ class AsyncMPClient(MPClient): output_socket = resources.output_socket assert output_socket is not None + notification_callback_handler: ( + Callable[[AsyncMPClient, Sequence[Any]], Any] | None + ) = getattr(self.__class__, "eep_process_engine_core_notification", None) + async def process_outputs_socket(): try: while True: @@ -884,7 +952,26 @@ class AsyncMPClient(MPClient): resources.validate_alive(frames) outputs: EngineCoreOutputs = decoder.decode(frames) if outputs.utility_output: - _process_utility_output(outputs.utility_output, utility_results) + if ( + outputs.utility_output.call_id == EEP_NOTIFICATION_CALL_ID + and notification_callback_handler is not None + ): + assert _self_ref is not None + _self = _self_ref() + if not _self: + return + if outputs.utility_output.result is None: + continue + notification_data = outputs.utility_output.result.result + assert isinstance(notification_data, Sequence) + assert len(notification_data) == 2 + asyncio.create_task( + notification_callback_handler(_self, notification_data) + ) + else: + _process_utility_output( + outputs.utility_output, utility_results + ) continue if output_handler is not None: @@ -1081,6 +1168,8 @@ class DPAsyncMPClient(AsyncMPClient): # Used only by DPLBAsyncMPClient subclass. self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines] + self.eep_scaling_cache: ElasticScalingCache | None = None + self.first_req_sock_addr = get_open_zmq_inproc_path() self.first_req_send_socket = self.resources.first_req_send_socket = ( make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True) @@ -1101,12 +1190,6 @@ class DPAsyncMPClient(AsyncMPClient): assert self.stats_update_address is not None stats_addr: str = self.stats_update_address assert len(self.engine_ranks_managed) > 0 - # NOTE: running and waiting counts are all global from - # the Coordinator include all global EngineCores. This - # slice includes just the cores managed by this client. 
- count_slice = slice( - self.engine_ranks_managed[0], self.engine_ranks_managed[-1] + 1 - ) async def run_engine_stats_update_task(): with ( @@ -1145,6 +1228,29 @@ class DPAsyncMPClient(AsyncMPClient): ): # Extract new engine count from the decoded message new_engine_count = decoded[1] + # Update engine_ranks_managed and count_slice + parallel_config = self.vllm_config.parallel_config + dp_size = parallel_config.data_parallel_size + dp_rank = parallel_config.data_parallel_rank + assert dp_rank == 0 + assert dp_size == new_engine_count + assert not ( + parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb + ) + num_ranks = dp_size + self.engine_ranks_managed = list( + range(dp_rank, dp_rank + num_ranks) + ) + if len(self.lb_engines) < new_engine_count: + self.lb_engines = self.lb_engines + [ + [0, 0] + for _ in range( + new_engine_count - len(self.lb_engines) + ) + ] + else: + self.lb_engines = self.lb_engines[:new_engine_count] # Send scale up notification to coordinator scale_msg = msgspec.msgpack.encode( ("SCALE_ELASTIC_EP", new_engine_count) @@ -1178,6 +1284,11 @@ class DPAsyncMPClient(AsyncMPClient): self.current_wave = wave self.engines_running = running if counts is not None: + # Running and waiting counts are global from the + # Coordinator including all EngineCores. Slice to get + # just the cores managed by this client. + ranks = self.engine_ranks_managed + count_slice = slice(ranks[0], ranks[-1] + 1) sliced_counts = counts[count_slice] self.lb_engines = sliced_counts logger.debug( @@ -1287,6 +1398,67 @@ class DPLBAsyncMPClient(DPAsyncMPClient): for req_id in outputs.finished_requests: self.reqs_in_flight.pop(req_id, None) + @staticmethod + async def eep_process_engine_core_notification( + self: "DPLBAsyncMPClient", notification_data: tuple[str, int] + ): + cache = self.eep_scaling_cache + notification_type_str, dp_rank = notification_data + try: + notification_type = EEPNotificationType(notification_type_str) + except ValueError as e: + raise ValueError( + f"Unknown EEP notification type: {notification_type_str}" + ) from e + + if notification_type == EEPNotificationType.RECONFIGURE_FINISHED: + from vllm.v1.engine import UtilityResult + + # NOTE(yongji): process a dummy UtilityOutput to resolve the future + # awaited in _eep_wait_for_setup_switch_complete(), signaling that + # all engine cores have completed reconfiguration. 
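The scale handling above grows or shrinks the per-engine `[running, waiting]` table and later slices the coordinator's global counts down to the locally managed ranks; both are small list manipulations that are easy to get off by one. A self-contained sketch of the same arithmetic with made-up counts:

```python
def resize_lb_table(lb_engines: list[list[int]], new_count: int) -> list[list[int]]:
    """Pad with fresh [running, waiting] = [0, 0] rows or truncate."""
    if len(lb_engines) < new_count:
        return lb_engines + [[0, 0] for _ in range(new_count - len(lb_engines))]
    return lb_engines[:new_count]


def slice_managed_counts(global_counts: list[list[int]], managed_ranks: list[int]):
    """Global counts cover every engine; keep only the managed ranks."""
    window = slice(managed_ranks[0], managed_ranks[-1] + 1)
    return global_counts[window]


lb = [[3, 1], [2, 0]]             # two engines before scale-up
lb = resize_lb_table(lb, 4)       # scale up to four engines
assert lb == [[3, 1], [2, 0], [0, 0], [0, 0]]

managed = list(range(0, 4))       # the rank-0 client now manages all four
counts = [[5, 2], [1, 0], [0, 0], [4, 4]]
assert slice_managed_counts(counts, managed) == counts

lb = resize_lb_table(lb, 2)       # scale back down
assert lb == [[3, 1], [2, 0]]
```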
+ dummy_output = UtilityOutput( + call_id=EEP_NOTIFICATION_CALL_ID, result=UtilityResult(None) + ) + _process_utility_output(dummy_output, self.utility_results) + return + assert cache is not None + if notification_type not in cache.pending_notifications: + cache.pending_notifications[notification_type] = set() + if dp_rank in cache.pending_notifications[notification_type]: + raise ValueError( + f"Duplicate notification {notification_type} from dp_rank {dp_rank}" + ) + cache.pending_notifications[notification_type].add(dp_rank) + if len(cache.pending_notifications[notification_type]) >= abs( + cache.num_new_core_engines + ): + if notification_type == EEPNotificationType.SHUTDOWN_COMPLETE: + assert isinstance(self.resources.engine_manager, CoreEngineActorManager) + assert cache.num_new_core_engines < 0 + old_dp_size = len(cache.existing_core_engines) + new_dp_size = old_dp_size + cache.num_new_core_engines + self.resources.engine_manager.scale_down_elastic_ep( + old_dp_size, new_dp_size + ) + else: + await asyncio.gather( + *[ + self._call_utility_async( + "eep_handle_engine_core_notification", + notification_type, + engine=engine, + ) + for engine in cache.existing_core_engines + ] + ) + cache.pending_notifications[notification_type] = set() + if notification_type in [ + EEPNotificationType.SHUTDOWN_COMPLETE, + EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY, + ]: + self.eep_scaling_cache = None + async def abort_requests_async(self, request_ids: list[str]) -> None: if not request_ids or self.resources.engine_dead: return @@ -1333,6 +1505,20 @@ class DPLBAsyncMPClient(DPAsyncMPClient): cur_data_parallel_size, new_data_parallel_size ) + async def _eep_wait_for_setup_switch_complete(self) -> None: + """ + Wait for core engines to switch to the new setup. + + In eep_process_engine_core_notification(), a dummy UtilityOutput with + EEP_NOTIFICATION_CALL_ID will be set when RECONFIGURE_FINISHED + notification is received from engine 0. We create a future with + that call_id and wait for it to be resolved. 
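`eep_process_engine_core_notification` above is effectively a per-notification-type barrier: dp ranks are collected until every affected engine has reported, then the action fires and the set is reset. A minimal standalone sketch of that counting logic, using plain strings in place of `EEPNotificationType` and a small dataclass in place of `ElasticScalingCache`:

```python
from dataclasses import dataclass, field


@dataclass
class ScalingBarrier:
    """Counts notifications of each type until every affected rank reported."""

    expected: int                                  # abs(num_new_core_engines)
    pending: dict[str, set[int]] = field(default_factory=dict)

    def record(self, kind: str, dp_rank: int) -> bool:
        """Return True when the last expected rank for `kind` has reported."""
        ranks = self.pending.setdefault(kind, set())
        if dp_rank in ranks:
            raise ValueError(f"duplicate {kind} from dp_rank {dp_rank}")
        ranks.add(dp_rank)
        if len(ranks) >= self.expected:
            self.pending[kind] = set()             # reset for the next round
            return True
        return False


barrier = ScalingBarrier(expected=2)               # e.g. two new engine cores
assert barrier.record("WEIGHTS_INIT_READY", 2) is False
assert barrier.record("WEIGHTS_INIT_READY", 3) is True   # both new ranks done
```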
+ """ + future = asyncio.get_running_loop().create_future() + self.utility_results[EEP_NOTIFICATION_CALL_ID] = future + self._ensure_output_queue_task() + await future + async def _scale_up_elastic_ep( self, cur_data_parallel_size: int, new_data_parallel_size: int ) -> None: @@ -1340,38 +1526,57 @@ class DPLBAsyncMPClient(DPAsyncMPClient): and reconfiguring existing ones.""" cur_data_parallel_size = len(self.core_engines) - # Phase 1: Send reconfigure messages to all existing engines and wait - # for them to be sent + self.eep_scaling_cache = ElasticScalingCache( + existing_core_engines=self.core_engines.copy(), + num_new_core_engines=new_data_parallel_size - cur_data_parallel_size, + pending_notifications=dict(), + ) + + parallel_config = self.vllm_config.parallel_config + allocate_stateless_group_ports(parallel_config, new_data_parallel_size) + + # Phase 1: Send reconfig messages to existing engines reconfig_futures = [] - self.vllm_config.parallel_config.data_parallel_master_port = get_open_port() for engine in self.core_engines: reconfig_request = ReconfigureDistributedRequest( new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip, - new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port, + new_data_parallel_master_ip=parallel_config.data_parallel_master_ip, + new_data_parallel_master_port=parallel_config.data_parallel_master_port, + new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list, + new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list, + new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list, + new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list, + new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list, ) coro = self._call_utility_async( "reinitialize_distributed", reconfig_request, engine=engine ) reconfig_futures.append(asyncio.create_task(coro)) - logger.info("All reconfigure messages sent, starting engine creation") - - # Phase 2: Create new engines now that reconfig messages have been sent - # self.resources.engine_manager is guaranteed to be - # CoreEngineActorManager for RayDPClient + # Phase 2: Create new engines assert isinstance(self.resources.engine_manager, CoreEngineActorManager) - self.resources.engine_manager.scale_up_elastic_ep( - self.vllm_config, new_data_parallel_size + parallel_config.eplb_config.num_redundant_experts = 0 + start_new_worker_future = asyncio.to_thread( + self.resources.engine_manager.scale_up_elastic_ep, + self.vllm_config, + new_data_parallel_size, ) + wait_future = self._eep_wait_for_setup_switch_complete() + + # Phase 3: Wait for new engines to be created + # and reconfig messages to be received + await asyncio.gather(start_new_worker_future, *reconfig_futures) + logger.info("[Elastic EP] Successfully started new engines") # Create new CoreEngine objects for the new engines new_engine_identities = set() for i in range(cur_data_parallel_size, new_data_parallel_size): new_engine = i.to_bytes(2, "little") self.core_engines.append(new_engine) + # NOTE(yongji): we don't update lb_engines here, + # we let run_engine_stats_update_task to update it. 
new_engine_identities.add(new_engine) # Wait for ready messages from new engines on the input socket @@ -1387,10 +1592,11 @@ class DPLBAsyncMPClient(DPAsyncMPClient): identity, _ = sync_input_socket.recv_multipart() new_engine_identities.discard(identity) - # Phase 3: Wait for all existing engines to complete reconfiguration - logger.info("Waiting for existing engines to complete reconfiguration") - await asyncio.gather(*reconfig_futures) - + # NOTE(yongji): Before we schedule any requests on the new workers, + # we should wait for them to switch to the new setup. + await wait_future + # Update the parallel config + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size # Notify coordinator about scale up through existing # stats_update_task connection self._ensure_stats_update_task() @@ -1399,8 +1605,6 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ) await self.first_req_send_socket.send(scale_up_marker) - # Update the parallel config - self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size logger.info( "[Elastic EP] Scale up completed, new data parallel size: %s", new_data_parallel_size, @@ -1413,7 +1617,14 @@ class DPLBAsyncMPClient(DPAsyncMPClient): reconfiguring existing engine cores.""" cur_data_parallel_size = len(self.core_engines) - self.vllm_config.parallel_config.data_parallel_master_port = get_open_port() + self.eep_scaling_cache = ElasticScalingCache( + existing_core_engines=self.core_engines.copy(), + num_new_core_engines=new_data_parallel_size - cur_data_parallel_size, + pending_notifications=dict(), + ) + + parallel_config = self.vllm_config.parallel_config + allocate_stateless_group_ports(parallel_config, new_data_parallel_size) reconfig_futures = [] for cur_dp_rank, engine in enumerate(self.core_engines): @@ -1421,8 +1632,13 @@ class DPLBAsyncMPClient(DPAsyncMPClient): new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip, - new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port, + new_data_parallel_master_ip=parallel_config.data_parallel_master_ip, + new_data_parallel_master_port=parallel_config.data_parallel_master_port, + new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list, + new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list, + new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list, + new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list, + new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list, ) if cur_dp_rank >= new_data_parallel_size: reconfig_request.new_data_parallel_rank = ( @@ -1433,23 +1649,24 @@ class DPLBAsyncMPClient(DPAsyncMPClient): ) reconfig_futures.append(asyncio.create_task(coro)) - for _ in range(new_data_parallel_size, cur_data_parallel_size): - self.core_engines.pop() + # NOTE(yongji): Immediately stop sending requests to the removing engines. 
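The coordinator notification sent here is just a msgpack-encoded `("SCALE_ELASTIC_EP", new_size)` tuple pushed over the existing first-request socket. A tiny round-trip sketch of that wire format; note that msgspec decodes msgpack arrays back as Python lists unless an explicit type is requested:

```python
import msgspec

new_dp_size = 4
scale_marker = msgspec.msgpack.encode(("SCALE_ELASTIC_EP", new_dp_size))

# On the receiving side the msgpack array comes back as a list.
decoded = msgspec.msgpack.decode(scale_marker)
assert decoded == ["SCALE_ELASTIC_EP", 4]

# Asking msgspec for a concrete type restores the tuple shape.
typed = msgspec.msgpack.decode(scale_marker, type=tuple[str, int])
assert typed == ("SCALE_ELASTIC_EP", 4)
```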
+ self.core_engines = self.core_engines[:new_data_parallel_size] + self.lb_engines = self.lb_engines[:new_data_parallel_size] + wait_future = self._eep_wait_for_setup_switch_complete() await asyncio.gather(*reconfig_futures) - assert isinstance(self.resources.engine_manager, CoreEngineActorManager) - self.resources.engine_manager.scale_down_elastic_ep( - cur_data_parallel_size, new_data_parallel_size - ) - + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size self._ensure_stats_update_task() scale_down_marker = msgspec.msgpack.encode( ("SCALE_ELASTIC_EP", new_data_parallel_size) ) await self.first_req_send_socket.send(scale_down_marker) - self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size + # NOTE(yongji): Unlike scaling up, + # here we don't actually need to wait for the setup switch to complete. + # We may want to remove it in the future. + await wait_future logger.info( "[Elastic EP] Scale down completed, new data parallel size: %s", new_data_parallel_size, diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 6c11087a3..a7d3c10b5 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -277,6 +277,8 @@ class CoreEngineActorManager: else: ray.init() + vllm_config.parallel_config.allocate_elastic_ep_ports() + if placement_groups is not None: assert local_dp_ranks is not None, ( "local_dp_ranks must be provided if placement_groups is provided" @@ -584,6 +586,8 @@ class CoreEngineActorManager: node_ip = node.node_ip node_id = node.node_id + if device_str not in available_resources[node_id]: + continue available_gpus = int(available_resources[node_id][device_str]) # Get total GPUs on this node from the node's resources @@ -773,11 +777,50 @@ class CoreEngineActorManager: ray.util.remove_placement_group(pg) +def get_engine_zmq_addresses( + vllm_config: VllmConfig, + num_api_servers: int = 1, +) -> EngineZmqAddresses: + """Allocate ZMQ addresses for engine-client communication.""" + parallel_config = vllm_config.parallel_config + local_engine_count = parallel_config.data_parallel_size_local + local_start_index = parallel_config.data_parallel_rank_local + dp_size = parallel_config.data_parallel_size + host = parallel_config.data_parallel_master_ip + local_engines_only = parallel_config.local_engines_only + + # In offline mode there is an LLM instance per DP rank and + # one core engine per LLM, see + # examples/offline_inference/data_parallel.py. + offline_mode = local_start_index is not None + + # client_local_only = True for cases where this front-end + # sends requests only to colocated engines. 
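The `client_local_only` comment above (and the matching `handshake_local_only` decision further down) describes the same predicate: prefer local IPC transport only when every engine is colocated and elastic EP is off, since scaling may later add engines on other nodes. A pure-function sketch of that rule with illustrative parameter names (they are not the actual `ParallelConfig` fields):

```python
def use_local_transport(
    offline_mode: bool,
    local_engines_only: bool,
    local_engine_count: int,
    dp_size: int,
    elastic_ep: bool,
) -> bool:
    """True -> IPC/inproc sockets; False -> TCP sockets reachable off-node."""
    if elastic_ep:
        # Scaling may add engines on other nodes, so always bind on TCP.
        return False
    return offline_mode or local_engines_only or local_engine_count == dp_size


# All engines on one node, no elastic EP: IPC is fine.
assert use_local_transport(False, False, 4, 4, elastic_ep=False) is True
# Same topology with elastic EP enabled: must stay on TCP.
assert use_local_transport(False, False, 4, 4, elastic_ep=True) is False
```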
+ client_local_only = ( + offline_mode or local_engines_only or (local_engine_count == dp_size) + ) + # NOTE(yongji): handling scaling from intra-node to inter-node + if parallel_config.enable_elastic_ep: + client_local_only = False + + return EngineZmqAddresses( + inputs=[ + get_engine_client_zmq_addr(client_local_only, host) + for _ in range(num_api_servers) + ], + outputs=[ + get_engine_client_zmq_addr(client_local_only, host) + for _ in range(num_api_servers) + ], + ) + + @contextlib.contextmanager def launch_core_engines( vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, + addresses: EngineZmqAddresses, num_api_servers: int = 1, ) -> Iterator[ tuple[ @@ -796,29 +839,8 @@ def launch_core_engines( host = parallel_config.data_parallel_master_ip local_engines_only = parallel_config.local_engines_only - # In offline mode there is an LLM instance per DP rank and - # one core engine per LLM, see - # examples/offline_inference/data_parallel.py. offline_mode = local_start_index is not None - # client_local_only = True for cases where this front-end - # sends requests only to colocated engines. - client_local_only = ( - offline_mode or local_engines_only or (local_engine_count == dp_size) - ) - - # Set up input and output addresses. - addresses = EngineZmqAddresses( - inputs=[ - get_engine_client_zmq_addr(client_local_only, host) - for _ in range(num_api_servers) - ], - outputs=[ - get_engine_client_zmq_addr(client_local_only, host) - for _ in range(num_api_servers) - ], - ) - # Run the DP Coordinator process with rank 0 when in online DP mode. # The coordinator is needed for: # 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing @@ -885,6 +907,10 @@ def launch_core_engines( # will be False. handshake_local_only = offline_mode or local_engine_count == dp_size + # NOTE(yongji): handling scaling from intra-node to inter-node + if parallel_config.enable_elastic_ep: + handshake_local_only = False + handshake_address = get_engine_client_zmq_addr( handshake_local_only, host, parallel_config.data_parallel_rpc_port ) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 9ea29df00..e3376ba2d 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -38,6 +38,7 @@ from vllm.distributed.parallel_state import ( get_pcp_group, get_pp_group, get_tp_group, + model_parallel_is_initialized, ) from vllm.envs import enable_envs_cache from vllm.logger import init_logger @@ -580,17 +581,20 @@ class WorkerProc: ) self.async_output_copy_thread.start() - # Initialize device - self.worker.init_device() - - # Set process title and log prefix self.setup_proc_title_and_log_prefix( enable_ep=vllm_config.parallel_config.enable_expert_parallel ) # Load model self._init_message_queues(input_shm_handle, vllm_config) - self.worker.load_model() + is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH + if not is_eep_new_worker: + self.worker.init_device() + # Update process title now that parallel groups are initialized + self.setup_proc_title_and_log_prefix( + enable_ep=vllm_config.parallel_config.enable_expert_parallel + ) + self.worker.load_model() # Enable environment variable cache (e.g. 
assume no more # environment variable overrides after this point) @@ -885,6 +889,13 @@ class WorkerProc: @staticmethod def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: + # Check if parallel groups are initialized first + if not model_parallel_is_initialized(): + # Parallel groups not yet initialized, use default process name + set_process_title(name="Worker") + decorate_logs("Worker") + return + dp_size = get_dp_group().world_size dp_rank = get_dp_group().rank_in_group pp_size = get_pp_group().world_size diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index ad51526ae..200de181a 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -382,8 +382,10 @@ class RayDistributedExecutor(Executor): all_kwargs.append(kwargs) self.collective_rpc("init_worker", args=(all_kwargs,)) - self.collective_rpc("init_device") - self.collective_rpc("load_model") + is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH + if not is_eep_new_worker: + self.collective_rpc("init_device") + self.collective_rpc("load_model") for pp_rank in range(self.parallel_config.pipeline_parallel_size): self.pp_tp_workers.append([]) diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index b9c7b5501..3759c751c 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -14,7 +14,6 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput -from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.serial_utils import run_method @@ -43,9 +42,11 @@ class UniProcExecutor(Executor): max_workers=1, thread_name_prefix="WorkerAsyncOutput" ) + is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH self.driver_worker.init_worker(all_kwargs=[kwargs]) - self.driver_worker.init_device() - self.driver_worker.load_model() + if not is_eep_new_worker: + self.driver_worker.init_device() + self.driver_worker.load_model() def _distributed_args(self) -> tuple[str, int, int]: """Return (distributed_init_method, rank, local_rank).""" @@ -122,16 +123,6 @@ class UniProcExecutor(Executor): # it's running. return - def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest - ) -> None: - self.driver_worker.reinitialize_distributed(reconfig_request) - if ( - reconfig_request.new_data_parallel_rank - == ReconfigureRankType.SHUTDOWN_CURRENT_RANK - ): - self.shutdown() - def shutdown(self) -> None: if worker := self.driver_worker: worker.shutdown() diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 8ee758353..489480004 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -53,7 +53,12 @@ class CPUModelRunner(GPUModelRunner): v.gpu = v.cpu @instrument(span_name="Loading (CPU)") - def load_model(self, eep_scale_up: bool = False) -> None: + def load_model(self, load_dummy_weights: bool = False) -> None: + if load_dummy_weights: + raise ValueError( + "Loading dummy weights (needed for elastic EP scale-up) " + "Is not supported by the CPU Model Runner." 
+ ) logger.info("Starting to load model %s...", self.model_config.model) self.model = get_model(vllm_config=self.vllm_config) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5e8de1429..59a82d4ce 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -461,6 +461,8 @@ class GPUModelRunner( self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.eplb_state: EplbState | None = None + # NOTE(yongji): flag to temporarily disable EPLB during scaling up/down + self.eep_eplb_suppressed = False """ State of the expert parallelism load balancer. @@ -2702,7 +2704,7 @@ class GPUModelRunner( """ Step for the EPLB (Expert Parallelism Load Balancing) state. """ - if not self.parallel_config.enable_eplb: + if not self.parallel_config.enable_eplb or self.eep_eplb_suppressed: return assert self.eplb_state is not None @@ -2714,6 +2716,23 @@ class GPUModelRunner( log_stats=self.parallel_config.eplb_config.log_balancedness, ) + def setup_eplb_from_mapping( + self, + expanded_physical_to_logical: torch.Tensor, + old_num_physical_experts: int, + ) -> None: + model = self.get_model() + assert is_mixture_of_experts(model) + + self.eplb_state = EplbState.from_mapping( + model=model, + model_config=self.model_config, + device=self.device, + parallel_config=self.parallel_config, + expanded_physical_to_logical=expanded_physical_to_logical, + num_valid_physical_experts=old_num_physical_experts, + ) + def _pool( self, hidden_states: torch.Tensor, @@ -4175,21 +4194,16 @@ class GPUModelRunner( setattr(self, config_name, new_config) @instrument(span_name="Loading (GPU)") - def load_model(self, eep_scale_up: bool = False) -> None: + def load_model(self, load_dummy_weights: bool = False) -> None: """ Args: - eep_scale_up: the model loading is for elastic EP scale up. + load_dummy_weights: load dummy weights instead of real weights. """ logger.info_once( "Starting to load model %s...", self.model_config.model, scope="global", ) - global_expert_loads, old_global_expert_indices_per_model, rank_mapping = ( - EplbState.get_eep_state(self.parallel_config) - if eep_scale_up - else (None, None, None) - ) if self.parallel_config.enable_eplb: self.eplb_state = EplbState(self.parallel_config, self.device) @@ -4198,6 +4212,8 @@ class GPUModelRunner( try: with DeviceMemoryProfiler() as m: time_before_load = time.perf_counter() + if load_dummy_weights: + self.load_config.load_format = "dummy" model_loader = get_model_loader(self.load_config) self.model = model_loader.load_model( vllm_config=self.vllm_config, model_config=self.model_config @@ -4214,6 +4230,9 @@ class GPUModelRunner( and is_mixture_of_experts(self.drafter.model) and self.parallel_config.enable_eplb ): + assert not self.parallel_config.enable_elastic_ep, ( + "Elastic EP is not supported with drafter model." 
+ ) spec_config = self.vllm_config.speculative_config assert spec_config is not None assert spec_config.draft_model_config is not None @@ -4221,17 +4240,6 @@ class GPUModelRunner( "EPLB is enabled for drafter model %s.", spec_config.draft_model_config.model, ) - - global_expert_load = ( - global_expert_loads[eplb_models] - if global_expert_loads - else None - ) - old_global_expert_indices = ( - old_global_expert_indices_per_model[eplb_models] - if old_global_expert_indices_per_model - else None - ) if self.eplb_state is None: self.eplb_state = EplbState( self.parallel_config, self.device @@ -4239,9 +4247,6 @@ class GPUModelRunner( self.eplb_state.add_model( self.drafter.model, spec_config.draft_model_config, - global_expert_load, - old_global_expert_indices, - rank_mapping, ) eplb_models += 1 @@ -4283,11 +4288,12 @@ class GPUModelRunner( time_after_load - time_before_load, scope="local", ) - prepare_communication_buffer_for_model(self.model) - if (drafter := getattr(self, "drafter", None)) and ( - drafter_model := getattr(drafter, "model", None) - ): - prepare_communication_buffer_for_model(drafter_model) + if not load_dummy_weights: + prepare_communication_buffer_for_model(self.model) + if (drafter := getattr(self, "drafter", None)) and ( + drafter_model := getattr(drafter, "model", None) + ): + prepare_communication_buffer_for_model(drafter_model) mm_config = self.model_config.multimodal_config self.is_multimodal_pruning_enabled = ( supports_multimodal_pruning(self.get_model()) @@ -4295,26 +4301,19 @@ class GPUModelRunner( and mm_config.is_multimodal_pruning_enabled() ) - if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: + if ( + is_mixture_of_experts(self.model) + and self.parallel_config.enable_eplb + and not load_dummy_weights + ): logger.info_once("EPLB is enabled for model %s.", self.model_config.model) - global_expert_load = ( - global_expert_loads[eplb_models] if global_expert_loads else None - ) - old_global_expert_indices = ( - old_global_expert_indices_per_model[eplb_models] - if old_global_expert_indices_per_model - else None - ) assert self.eplb_state is not None self.eplb_state.add_model( self.model, self.model_config, - global_expert_load, - old_global_expert_indices, - rank_mapping, ) if self.eplb_state.is_async: - self.eplb_state.start_async_loop(rank_mapping=rank_mapping) + self.eplb_state.start_async_loop() if ( self.vllm_config.compilation_config.mode diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 06410b2eb..07582ad96 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -7,11 +7,10 @@ import os from collections.abc import Callable from contextlib import AbstractContextManager, nullcontext from types import NoneType -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import numpy as np import torch -import torch.distributed import torch.nn as nn import vllm.envs as envs @@ -32,14 +31,12 @@ from vllm.distributed.kv_transfer import ( ) from vllm.distributed.parallel_state import ( Handle, - get_pcp_group, get_pp_group, get_tp_group, ) from vllm.distributed.weight_transfer import WeightTransferEngineFactory from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.models.interfaces import is_mixture_of_experts from vllm.model_executor.warmup.kernel_warmup import kernel_warmup from vllm.platforms import current_platform from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper @@ -49,7 +46,6 @@ 
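To make the dummy-weight bookkeeping concrete: the expanded physical-to-logical map has one row per MoE layer and one column per physical expert slot, and the redundant-expert count is simply physical slots minus logical experts. A standalone sketch with a made-up two-layer mapping (the actual tensor layout returned by `receive_expert_mapping` is internal to this change and only assumed here):

```python
import torch

# Hypothetical mapping: 2 MoE layers, 6 physical slots, 4 logical experts.
# Each entry is the logical expert id held by that physical slot.
expanded_physical_to_logical = torch.tensor(
    [
        [0, 1, 2, 3, 0, 2],  # layer 0: experts 0 and 2 are replicated
        [0, 1, 2, 3, 1, 3],  # layer 1: experts 1 and 3 are replicated
    ]
)
num_logical_experts = 4

num_physical_experts = expanded_physical_to_logical.shape[1]      # -> 6
num_redundant_experts = num_physical_experts - num_logical_experts  # -> 2

assert num_physical_experts == 6
assert num_redundant_experts == 2
# Every logical expert must appear in at least one physical slot per layer.
assert all(
    set(row.tolist()) == set(range(num_logical_experts))
    for row in expanded_physical_to_logical
)
```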
from vllm.tracing import instrument from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling from vllm.utils.torch_utils import set_random_seed from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput -from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ( AsyncModelRunnerOutput, @@ -124,6 +120,10 @@ class Worker(WorkerBase): precision = envs.VLLM_FLOAT32_MATMUL_PRECISION torch.set_float32_matmul_precision(precision) + from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor + + self.elastic_ep_executor = ElasticEPScalingExecutor(self) + # Buffers saved before sleep self._sleep_saved_buffers: dict[str, torch.Tensor] = {} @@ -317,12 +317,29 @@ class Worker(WorkerBase): # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool # to hijack tensor allocation. def load_model(self) -> None: - eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1" + dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1" + if dummy_weights: + ( + expanded_physical_to_logical, + num_logical_experts, + old_num_physical_experts, + ) = self.elastic_ep_executor.receive_expert_mapping() + num_physical_experts = expanded_physical_to_logical.shape[1] + self.parallel_config.eplb_config.num_redundant_experts = ( + num_physical_experts - num_logical_experts + ) + with ( self._maybe_get_memory_pool_context(tag="weights"), set_current_vllm_config(self.vllm_config), ): - self.model_runner.load_model(eep_scale_up=eep_scale_up) + self.model_runner.load_model(load_dummy_weights=dummy_weights) + + if dummy_weights: + self.model_runner.setup_eplb_from_mapping( + expanded_physical_to_logical, old_num_physical_experts + ) + self.model_runner.eep_eplb_suppressed = True def update_config(self, overrides: dict[str, Any]) -> None: self.model_runner.update_config(overrides) @@ -801,227 +818,6 @@ class Worker(WorkerBase): # worker will always be healthy as long as it's running. return - def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None: - from vllm.distributed.parallel_state import get_ep_group - - if get_ep_group().rank == 0: - logger.info( - "[Elastic EP] Starting expert resharding before scaling down..." 
- ) - rank_mapping = { - old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 - for old_ep_rank in range(old_ep_size) - } - assert self.model_runner.eplb_state is not None - self.model_runner.eplb_state.rearrange( - execute_shuffle=True, - global_expert_loads=None, - rank_mapping=rank_mapping, - ) - torch.cuda.synchronize() - if get_ep_group().rank == 0: - logger.info("[Elastic EP] Expert resharding completed!") - - def _eplb_after_scale_up( - self, - old_ep_size: int, - new_ep_size: int, - global_expert_loads: list[torch.Tensor] | None, - ) -> None: - from vllm.distributed.parallel_state import get_ep_group - - if get_ep_group().rank == 0: - logger.info("[Elastic EP] Starting expert resharding after scaling up...") - rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)} - assert self.model_runner.eplb_state is not None - self.model_runner.eplb_state.rearrange( - execute_shuffle=True, - global_expert_loads=global_expert_loads, - rank_mapping=rank_mapping, - ) - if get_ep_group().rank == 0: - logger.info("[Elastic EP] Expert resharding completed!") - - def _reconfigure_parallel_config( - self, reconfig_request: ReconfigureDistributedRequest - ) -> None: - """ - Update parallel config with provided reconfig_request - """ - parallel_config = self.vllm_config.parallel_config - parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size - if ( - reconfig_request.new_data_parallel_rank - != ReconfigureRankType.KEEP_CURRENT_RANK - ): - parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank - if ( - reconfig_request.new_data_parallel_rank_local - != ReconfigureRankType.KEEP_CURRENT_RANK - ): - parallel_config.data_parallel_rank_local = ( - reconfig_request.new_data_parallel_rank_local - ) - parallel_config.data_parallel_master_ip = ( - reconfig_request.new_data_parallel_master_ip - ) - parallel_config.data_parallel_master_port = ( - reconfig_request.new_data_parallel_master_port - ) - - def _reconfigure_moe( - self, old_ep_size: int, new_ep_size: int - ) -> list[torch.Tensor] | None: - """ - Reconfigure MoE modules with provided reconfig_request - - Return the global expert load if new_ep_size > old_ep_size, - otherwise None - """ - from vllm.distributed.parallel_state import ( - get_dp_group, - get_ep_group, - prepare_communication_buffer_for_model, - ) - from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, - FusedMoEParallelConfig, - ) - - parallel_config = self.vllm_config.parallel_config - - def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]: - return [ - module - for module in model.modules() - if ( - module.__class__.__name__ == "FusedMoE" - or module.__class__.__name__ == "SharedFusedMoE" - ) - ] - - def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int): - assert all( - module.moe_config.num_local_experts == num_local_experts - for module in moe_modules - ), "All MoE modules must have the same number of experts" - for module in moe_modules: - module.moe_config.num_experts = num_local_experts * new_ep_size - module.global_num_experts = module.moe_config.num_experts - tp_size = get_tp_group().world_size - is_sequence_parallel = parallel_config.use_sequence_parallel_moe - sp_size = tp_size if is_sequence_parallel else 1 - module.moe_parallel_config = FusedMoEParallelConfig.make( - tp_size_=tp_size, - pcp_size_=get_pcp_group().world_size, - dp_size_=get_dp_group().world_size, - sp_size_=sp_size, - vllm_parallel_config=parallel_config, - ) - 
module.moe_config.moe_parallel_config = module.moe_parallel_config - return moe_modules - - model_moe_modules = get_moe_modules(self.model_runner.model) - num_local_experts = model_moe_modules[0].moe_config.num_local_experts - - update_moe_modules(model_moe_modules, num_local_experts) - drafter_model = None - if hasattr(self.model_runner, "drafter") and hasattr( - self.model_runner.drafter, "model" - ): - drafter_model = self.model_runner.drafter.model - if drafter_model is not None and is_mixture_of_experts(drafter_model): - drafter_moe_modules = get_moe_modules(drafter_model) - # Check if drafter and model have matching configs - assert ( - drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts - ), "Drafter and model configs should be the same" - update_moe_modules(drafter_moe_modules, num_local_experts) - - if new_ep_size < old_ep_size: - num_local_physical_experts = num_local_experts - assert self.model_runner.eplb_state is not None - new_physical_experts = ( - self.model_runner.eplb_state.physical_to_logical_map.shape[1] # type: ignore[attr-defined] - ) - parallel_config.eplb_config.num_redundant_experts = ( - new_physical_experts - - self.model_runner.eplb_state.logical_replica_count.shape[1] # type: ignore[attr-defined] - ) - global_expert_loads = None - else: - num_local_physical_experts_tensor = torch.tensor( - [num_local_experts], dtype=torch.int32, device="cpu" - ) - torch.distributed.broadcast( - num_local_physical_experts_tensor, - group=get_ep_group().cpu_group, - group_src=0, - ) - num_local_physical_experts = int(num_local_physical_experts_tensor.item()) - new_physical_experts = num_local_physical_experts * new_ep_size - assert self.model_runner.eplb_state is not None - global_expert_loads_any = self.model_runner.eplb_state.rearrange( - execute_shuffle=False - ) - global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any) - parallel_config.eplb_config.num_redundant_experts = ( - new_physical_experts - global_expert_loads[0].shape[1] - ) - prepare_communication_buffer_for_model(self.model_runner.model) - if drafter_model is not None: - prepare_communication_buffer_for_model(drafter_model) - self.model_runner.model.update_physical_experts_metadata( - num_physical_experts=new_physical_experts, - num_local_physical_experts=num_local_physical_experts, - ) - return global_expert_loads - - def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest - ) -> None: - from vllm.config import set_current_vllm_config - from vllm.distributed.parallel_state import ( - cleanup_dist_env_and_memory, - get_ep_group, - ) - - old_ep_size = get_ep_group().world_size - old_ep_rank = get_ep_group().rank - new_ep_size = ( - reconfig_request.new_data_parallel_size - * get_tp_group().world_size - * get_pp_group().world_size - ) - if new_ep_size < old_ep_size: - self._eplb_before_scale_down(old_ep_size, new_ep_size) - - cleanup_dist_env_and_memory() - - if ( - reconfig_request.new_data_parallel_rank - == ReconfigureRankType.SHUTDOWN_CURRENT_RANK - ): - assert old_ep_rank >= new_ep_size - # shutdown - return - - self._reconfigure_parallel_config(reconfig_request) - - with set_current_vllm_config(self.vllm_config): - init_worker_distributed_environment( - self.vllm_config, - self.rank, - self.distributed_init_method, - self.local_rank, - ) - - global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size) - - if new_ep_size > old_ep_size: - assert global_expert_loads is not None - self._eplb_after_scale_up(old_ep_size, new_ep_size, 
global_expert_loads) - def save_sharded_state( self, path: str, @@ -1118,6 +914,9 @@ class Worker(WorkerBase): if weight_transfer_engine := getattr(self, "weight_transfer_engine", None): weight_transfer_engine.shutdown() + def elastic_ep_execute(self, execute_method: str, *args, **kwargs): + return self.elastic_ep_executor.execute(execute_method, *args, **kwargs) + def init_worker_distributed_environment( vllm_config: VllmConfig, diff --git a/vllm/v1/worker/workspace.py b/vllm/v1/worker/workspace.py index ef32a32f6..28ba85a26 100644 --- a/vllm/v1/worker/workspace.py +++ b/vllm/v1/worker/workspace.py @@ -66,6 +66,23 @@ class WorkspaceManager: ], ) + def unlock(self) -> None: + """Unlock the workspace to allow growth. + + This is used during elastic EP scaling when the workspace size + needs to grow due to changes in the number of experts. + """ + self._locked = False + if envs.VLLM_DEBUG_WORKSPACE: + logger.info( + "[WORKSPACE DEBUG] Workspace unlocked. Current sizes: %s", + [ + self._workspace_size_bytes(ws) / _MB + for ws in self._current_workspaces + if ws is not None + ], + ) + def is_locked(self) -> bool: """Check if workspace is locked.""" return self._locked @@ -242,6 +259,17 @@ def lock_workspace() -> None: current_workspace_manager().lock() +def unlock_workspace() -> None: + """Unlock the workspace to allow growth. + + This is used during elastic EP scaling when the workspace size + needs to grow due to changes in the number of experts. + After scaling operations complete, lock_workspace() should be + called again to prevent unexpected allocations. + """ + current_workspace_manager().unlock() + + def reset_workspace_manager() -> None: """Reset the workspace manager to uninitialized state.