diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index ffdf4b83c..65701b78b 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -278,7 +278,8 @@ steps: - popd # NEW rlhf examples - pushd ../examples/offline_inference/new_weight_syncing - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - popd diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index 9b5b002f4..0a75bc50e 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -103,7 +103,8 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py # NEW rlhf examples - cd new_weight_syncing - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py - label: Distributed Tests (8 GPUs)(H100) timeout_in_minutes: 10 diff --git a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py index 8714eb92b..88b89fbfc 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py @@ -42,6 +42,7 @@ from vllm.distributed.weight_transfer.base import ( WeightTransferUpdateRequest, ) from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLTrainerSendWeightsArgs, NCCLWeightTransferEngine, NCCLWeightTransferInitInfo, NCCLWeightTransferUpdateInfo, @@ -152,11 +153,14 @@ class TrainModel: def broadcast_weights(self, packed: bool = True): """Broadcast weights to the inference engine.""" - NCCLWeightTransferEngine.trainer_send_weights( - iterator=self.model.named_parameters(), + trainer_args = 
NCCLTrainerSendWeightsArgs( group=self.model_update_group, packed=packed, ) + NCCLWeightTransferEngine.trainer_send_weights( + iterator=self.model.named_parameters(), + trainer_args=trainer_args, + ) @torch.inference_mode() def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]: diff --git a/examples/offline_inference/new_weight_syncing/rlhf_ipc.py b/examples/offline_inference/new_weight_syncing/rlhf_ipc.py new file mode 100644 index 000000000..169b1026a --- /dev/null +++ b/examples/offline_inference/new_weight_syncing/rlhf_ipc.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray, +with IPC-based weight syncing APIs + +The script colocates the training and inference workloads onto the same GPU using Ray. + +The example performs the following steps: + +* Request a placement group of 1 GPU. +* Place the inference model on the above GPU using the placement group. +* Place and load the training model on the same GPU using the placement group. +* Generate text from a list of prompts using the inference engine. +* Update the weights of the training model and broadcast the updated weights + to the inference engine by using CUDA IPC handles. Note that + for demonstration purposes we simply zero out the weights. + +This example assumes a single-node cluster with a single GPU, +but can be extended to multiple GPUs. 
+""" + +import os + +import ray +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from transformers import AutoModelForCausalLM + +from vllm import LLM, SamplingParams +from vllm.config import WeightTransferConfig +from vllm.distributed.weight_transfer.ipc_engine import ( + IPCTrainerSendWeightsArgs, + IPCWeightTransferEngine, +) + + +class MyLLM(LLM): + """Configure the vLLM worker for Ray placement group execution.""" + + def __init__(self, *args, **kwargs): + # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray + # so that vLLM can manage its own device placement within the worker. + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + # Each worker uses 0.4 GPU so that two instances fit on the same GPU. + os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4" + os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0" + # needed for ipc handle serialization + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + super().__init__(*args, **kwargs) + + +# Load the OPT-125M model onto GPU 0 for the training workload. 
+ +MODEL_NAME = "facebook/opt-125m" + + +@ray.remote +class TrainModel: + def __init__(self, llm_handle: ray.actor.ActorHandle): + self.train_model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + ) + self.train_model.to("cuda:0") + self.llm_handle = llm_handle + + def init_weight_transfer(self): + # IPC backend doesn't need initialization info + ray.get( + self.llm_handle.init_weight_transfer_engine.remote(dict(init_info=dict())) + ) + + def broadcast_weights(self, llm_handle: ray.actor.ActorHandle): + """Broadcast weights to the inference engine using IPC.""" + self.llm_handle = llm_handle + trainer_args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle) + IPCWeightTransferEngine.trainer_send_weights( + iterator=self.train_model.named_parameters(), + trainer_args=trainer_args, + ) + + +ray.init() + +pg_colocate = placement_group([{"GPU": 1, "CPU": 0}]) +ray.get(pg_colocate.ready()) + + +llm = ray.remote( + num_cpus=0, + num_gpus=0, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg_colocate, + placement_group_capture_child_tasks=True, + ), +)(MyLLM).remote( + model=MODEL_NAME, + enforce_eager=True, + tensor_parallel_size=1, + distributed_executor_backend="ray", + gpu_memory_utilization=0.7, + weight_transfer_config=WeightTransferConfig(backend="ipc"), + load_format="dummy", +) + +train_model = TrainModel.options( + num_gpus=0.1, + num_cpus=0, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg_colocate, placement_group_capture_child_tasks=True + ), +).remote(llm) + + +# Generate text from the prompts. 
+prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +sampling_params = SamplingParams(temperature=0) + +outputs = ray.get(llm.generate.remote(prompts, sampling_params)) + +print("-" * 50) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + +ray.get(llm.sleep.remote(level=0)) + +ray.get(train_model.init_weight_transfer.remote()) +# Synchronize the updated weights to the inference engine using batched API. +ray.get(train_model.broadcast_weights.remote(llm)) + +ray.get(llm.wake_up.remote(tags=["scheduling"])) + +# Generate text with the updated model. +outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) +print("-" * 50) +for output in outputs_updated: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) diff --git a/examples/offline_inference/new_weight_syncing/rlhf.py b/examples/offline_inference/new_weight_syncing/rlhf_nccl.py similarity index 97% rename from examples/offline_inference/new_weight_syncing/rlhf.py rename to examples/offline_inference/new_weight_syncing/rlhf_nccl.py index b3a3ca62f..5d5f24a93 100644 --- a/examples/offline_inference/new_weight_syncing/rlhf.py +++ b/examples/offline_inference/new_weight_syncing/rlhf_nccl.py @@ -36,6 +36,7 @@ from transformers import AutoModelForCausalLM from vllm import LLM, SamplingParams from vllm.config import WeightTransferConfig from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLTrainerSendWeightsArgs, NCCLWeightTransferEngine, ) from vllm.utils.network_utils import get_ip, get_open_port @@ -90,11 +91,14 @@ class TrainModel: def broadcast_weights(self, packed: bool = True): """Broadcast weights to the inference engine.""" - NCCLWeightTransferEngine.trainer_send_weights( - 
iterator=self.model.named_parameters(), + trainer_args = NCCLTrainerSendWeightsArgs( group=self.model_update_group, packed=packed, ) + NCCLWeightTransferEngine.trainer_send_weights( + iterator=self.model.named_parameters(), + trainer_args=trainer_args, + ) # Initialize Ray and set the visible devices. The vLLM engine will @@ -156,6 +160,8 @@ for output in outputs: print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") print("-" * 50) +ray.get(llm.sleep.remote(level=0)) + # Set up the communication channel between the training process and the # inference engine. master_address, master_port = ray.get(train_model.get_master_address_and_port.remote()) @@ -197,6 +203,8 @@ inference_handle = llm.update_weights.remote( train_handle = train_model.broadcast_weights.remote(packed=True) ray.get([train_handle, inference_handle]) +ray.get(llm.wake_up.remote(tags=["scheduling"])) + # Generate text with the updated model. The output is expected to be normal # because the weights are updated. outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) diff --git a/examples/online_serving/new_weight_syncing/rlhf_http_ipc.py b/examples/online_serving/new_weight_syncing/rlhf_http_ipc.py new file mode 100644 index 000000000..d73eba64c --- /dev/null +++ b/examples/online_serving/new_weight_syncing/rlhf_http_ipc.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Demonstrates reinforcement learning from human feedback (RLHF) using vLLM +via HTTP API, with IPC-based weight syncing APIs. + +Unlike rlhf_nccl.py which uses NCCL and can use separate GPUs, this script +uses CUDA IPC which requires the training model and vLLM server to be on the +same GPU. Memory must be carefully managed to fit both models. + +Unlike rlhf.py which creates a vLLM instance programmatically, this script +assumes you have already started a vLLM server using `vllm serve`. 
It uses: +- OpenAI-compatible API for inference requests +- HTTP endpoints for weight transfer control plane +- CUDA IPC for actual weight data transfer + +Prerequisites: + Start a vLLM server with weight transfer enabled and reduced GPU memory + utilization to leave room for the training model: + + $ VLLM_SERVER_DEV_MODE=1 VLLM_ALLOW_INSECURE_SERIALIZATION=1 \ + vllm serve facebook/opt-125m --enforce-eager \ + --weight-transfer-config '{"backend": "ipc"}' \ + --load-format dummy \ + --gpu-memory-utilization 0.5 + + Then run this script: + + $ python rlhf_http_ipc.py + +The example performs the following steps: + +* Load the training model on GPU 0 (same GPU as the vLLM server). +* Generate text using the vLLM server via OpenAI-compatible API. The output + is expected to be nonsense because the server is initialized with dummy weights. +* Initialize weight transfer via HTTP endpoint (no-op for IPC). +* Broadcast the real weights from the training model to the vLLM server + using CUDA IPC handles. +* Generate text again to show normal output after the weight update. 
+""" + +import os + +import requests +import torch +from openai import OpenAI +from transformers import AutoModelForCausalLM + +from vllm.distributed.weight_transfer.ipc_engine import ( + IPCTrainerSendWeightsArgs, + IPCWeightTransferEngine, +) + +BASE_URL = "http://localhost:8000" +MODEL_NAME = "facebook/opt-125m" + +# Enable insecure serialization for IPC handle serialization +os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + + +def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]: + """Generate completions using the OpenAI-compatible API.""" + results = [] + for prompt in prompts: + response = client.completions.create( + model=model, + prompt=prompt, + max_tokens=32, + temperature=0, + ) + results.append(response.choices[0].text) + return results + + +def init_weight_transfer_engine(base_url: str) -> None: + """Initialize weight transfer via HTTP endpoint (no-op for IPC).""" + url = f"{base_url}/init_weight_transfer_engine" + payload = {"init_info": dict()} + response = requests.post(url, json=payload, timeout=60) + response.raise_for_status() + + +def pause_generation(base_url: str) -> None: + """Pause generation via HTTP endpoint.""" + url = f"{base_url}/pause" + response = requests.post(url, timeout=60) + response.raise_for_status() + + +def resume_generation(base_url: str) -> None: + """Resume generation via HTTP endpoint.""" + url = f"{base_url}/resume" + response = requests.post(url, timeout=60) + response.raise_for_status() + + +def get_world_size(base_url: str) -> int: + """Get world size from the vLLM server.""" + url = f"{base_url}/get_world_size" + response = requests.get(url, timeout=10) + response.raise_for_status() + return response.json()["world_size"] + + +def main(): + # IPC requires the training model to be on the same GPU as the vLLM server + # The server should be started on GPU 0 with reduced memory utilization + device = "cuda:0" + torch.cuda.set_device(device) + + # Load the training model on the 
same GPU as the server + # Use bfloat16 to reduce memory footprint + print(f"Loading training model: {MODEL_NAME} on {device}") + print( + "Note: Ensure the vLLM server was started with --gpu-memory-utilization 0.5 " + "or lower to leave room for the training model." + ) + train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16) + train_model.to(device) + train_model.eval() # Set to eval mode to save memory + + # Create OpenAI client pointing to the vLLM server + client = OpenAI( + base_url=f"{BASE_URL}/v1", + api_key="EMPTY", # vLLM doesn't require an API key by default + ) + + # Test prompts + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # Generate text before weight update. The output is expected to be nonsense + # because the server is initialized with dummy weights. + print("-" * 50) + print("Generating text BEFORE weight update (expect nonsense):") + print("-" * 50) + outputs = generate_completions(client, MODEL_NAME, prompts) + for prompt, generated_text in zip(prompts, outputs): + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + + print("Initializing weight transfer (IPC backend)...") + + # Initialize weight transfer on vLLM server (no-op for IPC, but still required) + init_weight_transfer_engine(BASE_URL) + + # Pause generation before weight sync + pause_generation(BASE_URL) + + # Broadcast weights via IPC handles using HTTP mode + print("Broadcasting weights via CUDA IPC (HTTP)...") + trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL) + IPCWeightTransferEngine.trainer_send_weights( + iterator=train_model.named_parameters(), + trainer_args=trainer_args, + ) + + # Resume generation after weight sync + resume_generation(BASE_URL) + + # Generate text after weight update. The output is expected to be normal + # because the real weights are now loaded. 
+ print("-" * 50) + print("Generating text AFTER weight update:") + print("-" * 50) + outputs_updated = generate_completions(client, MODEL_NAME, prompts) + for prompt, generated_text in zip(prompts, outputs_updated): + print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") + print("-" * 50) + + # Note: The training model and IPC handles remain in memory. + # In a real RLHF training loop, you would update the training model + # and create new IPC handles for each weight update. + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/rlhf_http.py b/examples/online_serving/new_weight_syncing/rlhf_http_nccl.py similarity index 98% rename from examples/online_serving/rlhf_http.py rename to examples/online_serving/new_weight_syncing/rlhf_http_nccl.py index 721a038a6..b8a6b180a 100644 --- a/examples/online_serving/rlhf_http.py +++ b/examples/online_serving/new_weight_syncing/rlhf_http_nccl.py @@ -39,6 +39,7 @@ from openai import OpenAI from transformers import AutoModelForCausalLM from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLTrainerSendWeightsArgs, NCCLWeightTransferEngine, ) from vllm.utils.network_utils import get_ip, get_open_port @@ -214,11 +215,14 @@ def main(): # Broadcast all weights from trainer to vLLM workers print("Broadcasting weights via NCCL...") - NCCLWeightTransferEngine.trainer_send_weights( - iterator=train_model.named_parameters(), + trainer_args = NCCLTrainerSendWeightsArgs( group=model_update_group, packed=True, ) + NCCLWeightTransferEngine.trainer_send_weights( + iterator=train_model.named_parameters(), + trainer_args=trainer_args, + ) # Wait for update_weights to complete update_thread.join() diff --git a/tests/distributed/test_weight_transfer.py b/tests/distributed/test_weight_transfer.py index 4c348dd79..04747e732 100644 --- a/tests/distributed/test_weight_transfer.py +++ b/tests/distributed/test_weight_transfer.py @@ -3,18 +3,26 @@ """Tests for weight transfer engine backends. 
Unit tests for engine classes (parsing, validation, registry). -Integration test for NCCL weight transfer between processes using Ray. +Integration tests for NCCL and IPC weight transfer between processes using Ray. """ +import base64 +import pickle from unittest.mock import MagicMock import pytest import ray import torch +from torch.multiprocessing.reductions import reduce_tensor from vllm.config.parallel import ParallelConfig from vllm.config.weight_transfer import WeightTransferConfig from vllm.distributed.weight_transfer import WeightTransferEngineFactory +from vllm.distributed.weight_transfer.ipc_engine import ( + IPCWeightTransferEngine, + IPCWeightTransferInitInfo, + IPCWeightTransferUpdateInfo, +) from vllm.distributed.weight_transfer.nccl_engine import ( NCCLWeightTransferEngine, NCCLWeightTransferInitInfo, @@ -155,9 +163,29 @@ class TestEngineRegistry: engine = WeightTransferEngineFactory.create_engine(config, parallel_config) assert isinstance(engine, NCCLWeightTransferEngine) + def test_create_engine_ipc(self): + """Test factory creates IPC engine.""" + config = WeightTransferConfig(backend="ipc") + parallel_config = create_mock_parallel_config() + engine = WeightTransferEngineFactory.create_engine(config, parallel_config) + assert isinstance(engine, IPCWeightTransferEngine) + def test_create_engine_invalid_backend(self): """Test factory raises for invalid backend.""" - config = WeightTransferConfig(backend="invalid") + # Pydantic validates Literal types at construction, so we can't create + # a config with an invalid backend. Instead, we test by directly + # accessing the registry or using model_construct to bypass validation. 
+ from pydantic import ValidationError + + # Test that Pydantic prevents invalid backend at construction + with pytest.raises(ValidationError): + WeightTransferConfig(backend="invalid") + + # Test factory error by creating a config with valid backend but + # then manually modifying the backend attribute (bypassing validation) + config = WeightTransferConfig(backend="nccl") + # Use object.__setattr__ to bypass Pydantic validation + object.__setattr__(config, "backend", "invalid") parallel_config = create_mock_parallel_config() with pytest.raises(ValueError, match="Invalid weight transfer backend"): WeightTransferEngineFactory.create_engine(config, parallel_config) @@ -344,3 +372,426 @@ def test_nccl_weight_transfer_between_processes(): f"Received shape: {result['received_shape']}, " f"Received sum: {result['received_sum']}" ) + + +# --- Unit Tests: IPCWeightTransferUpdateInfo Validation --- + + +class TestIPCWeightTransferUpdateInfoValidation: + """Test IPCWeightTransferUpdateInfo dataclass validation.""" + + def test_valid_update_info(self): + """Test creating valid IPCWeightTransferUpdateInfo.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + # Create a dummy tensor and IPC handle + dummy_tensor = torch.ones(10, 10, device="cuda:0") + ipc_handle = reduce_tensor(dummy_tensor) + gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) + ipc_handles = [{gpu_uuid: ipc_handle}] + + info = IPCWeightTransferUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + ipc_handles=ipc_handles, + ) + assert info.names == ["layer.weight"] + assert info.dtype_names == ["float32"] + assert info.shapes == [[10, 10]] + assert len(info.ipc_handles) == 1 + + def test_mismatched_dtype_names_raises(self): + """Test that mismatched dtype_names length raises ValueError.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + dummy_tensor = torch.ones(10, 10, device="cuda:0") + 
ipc_handle = reduce_tensor(dummy_tensor) + gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) + ipc_handles = [{gpu_uuid: ipc_handle}, {gpu_uuid: ipc_handle}] + + with pytest.raises(ValueError, match="dtype_names"): + IPCWeightTransferUpdateInfo( + names=["layer.weight", "layer.bias"], + dtype_names=["float32"], # Only one dtype + shapes=[[10, 10], [10]], + ipc_handles=ipc_handles, + ) + + def test_mismatched_shapes_raises(self): + """Test that mismatched shapes length raises ValueError.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + dummy_tensor = torch.ones(10, 10, device="cuda:0") + ipc_handle = reduce_tensor(dummy_tensor) + gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) + ipc_handles = [{gpu_uuid: ipc_handle}, {gpu_uuid: ipc_handle}] + + with pytest.raises(ValueError, match="shapes"): + IPCWeightTransferUpdateInfo( + names=["layer.weight", "layer.bias"], + dtype_names=["float32", "float32"], + shapes=[[10, 10]], # Only one shape + ipc_handles=ipc_handles, + ) + + def test_mismatched_ipc_handles_raises(self): + """Test that mismatched ipc_handles length raises ValueError.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + dummy_tensor = torch.ones(10, 10, device="cuda:0") + ipc_handle = reduce_tensor(dummy_tensor) + gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) + ipc_handles = [{gpu_uuid: ipc_handle}] # Only one handle + + with pytest.raises(ValueError, match="ipc_handles"): + IPCWeightTransferUpdateInfo( + names=["layer.weight", "layer.bias"], + dtype_names=["float32", "float32"], + shapes=[[10, 10], [10]], + ipc_handles=ipc_handles, + ) + + def test_valid_update_info_from_pickled(self): + """Test creating IPCWeightTransferUpdateInfo from pickled handles.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + dummy_tensor = torch.ones(10, 10, device="cuda:0") + ipc_handle = reduce_tensor(dummy_tensor) + 
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) + ipc_handles = [{gpu_uuid: ipc_handle}] + + pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8") + + info = IPCWeightTransferUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + ipc_handles_pickled=pickled, + ) + assert info.ipc_handles == ipc_handles + assert info.ipc_handles_pickled is None + + def test_both_handles_and_pickled_raises(self): + """Test that providing both ipc_handles and ipc_handles_pickled raises.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + dummy_tensor = torch.ones(10, 10, device="cuda:0") + ipc_handle = reduce_tensor(dummy_tensor) + gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) + ipc_handles = [{gpu_uuid: ipc_handle}] + + pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8") + + with pytest.raises(ValueError, match="Cannot specify both"): + IPCWeightTransferUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + ipc_handles=ipc_handles, + ipc_handles_pickled=pickled, + ) + + def test_neither_handles_nor_pickled_raises(self): + """Test that providing neither ipc_handles nor ipc_handles_pickled raises.""" + with pytest.raises(ValueError, match="must be provided"): + IPCWeightTransferUpdateInfo( + names=["layer.weight"], + dtype_names=["float32"], + shapes=[[10, 10]], + ) + + def test_empty_lists_valid(self): + """Test that empty lists are valid.""" + info = IPCWeightTransferUpdateInfo( + names=[], + dtype_names=[], + shapes=[], + ipc_handles=[], + ) + assert len(info.names) == 0 + + +# --- Unit Tests: IPC Engine Parsing --- + + +class TestIPCEngineParsing: + """Test IPCWeightTransferEngine parsing methods.""" + + def test_parse_update_info_valid(self): + """Test parsing valid update info dict.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + config = WeightTransferConfig(backend="ipc") + 
parallel_config = create_mock_parallel_config() + engine = IPCWeightTransferEngine(config, parallel_config) + + # Create dummy IPC handles + dummy_tensor1 = torch.ones(100, 100, device="cuda:0") + dummy_tensor2 = torch.ones(50, device="cuda:0") + ipc_handle1 = reduce_tensor(dummy_tensor1) + ipc_handle2 = reduce_tensor(dummy_tensor2) + gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) + ipc_handles = [{gpu_uuid: ipc_handle1}, {gpu_uuid: ipc_handle2}] + + update_info = engine.parse_update_info( + { + "names": ["w1", "w2"], + "dtype_names": ["float32", "bfloat16"], + "shapes": [[100, 100], [50]], + "ipc_handles": ipc_handles, + } + ) + + assert isinstance(update_info, IPCWeightTransferUpdateInfo) + assert update_info.names == ["w1", "w2"] + assert update_info.dtype_names == ["float32", "bfloat16"] + assert update_info.shapes == [[100, 100], [50]] + assert len(update_info.ipc_handles) == 2 + + def test_parse_update_info_pickled(self): + """Test parsing update info with pickled IPC handles (HTTP path).""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + config = WeightTransferConfig(backend="ipc") + parallel_config = create_mock_parallel_config() + engine = IPCWeightTransferEngine(config, parallel_config) + + dummy_tensor1 = torch.ones(100, 100, device="cuda:0") + dummy_tensor2 = torch.ones(50, device="cuda:0") + ipc_handle1 = reduce_tensor(dummy_tensor1) + ipc_handle2 = reduce_tensor(dummy_tensor2) + gpu_uuid = str(torch.cuda.get_device_properties(0).uuid) + ipc_handles = [{gpu_uuid: ipc_handle1}, {gpu_uuid: ipc_handle2}] + + pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8") + + update_info = engine.parse_update_info( + { + "names": ["w1", "w2"], + "dtype_names": ["float32", "bfloat16"], + "shapes": [[100, 100], [50]], + "ipc_handles_pickled": pickled, + } + ) + + assert isinstance(update_info, IPCWeightTransferUpdateInfo) + assert update_info.names == ["w1", "w2"] + assert 
len(update_info.ipc_handles) == 2 + assert update_info.ipc_handles_pickled is None + assert gpu_uuid in update_info.ipc_handles[0] + assert gpu_uuid in update_info.ipc_handles[1] + + +# --- Integration Test: IPC Weight Transfer Between Ray Tasks --- + + +def get_physical_gpu_id(device_index: int = 0) -> str: + """Get physical GPU UUID for a device.""" + props = torch.cuda.get_device_properties(device_index) + return str(props.uuid) + + +@ray.remote(num_gpus=0.5) +class TrainerActor: + """Trainer actor that creates and holds CUDA IPC handles.""" + + def __init__(self, tensor_shape: list[int], tensor_dtype: str): + # Create tensor on GPU and keep it alive + dtype = getattr(torch, tensor_dtype) + self.tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda:0") + self.tensor.fill_(42.0) # Fill with 42 to verify correct transfer + + # Create IPC handle (tensor must stay alive for IPC to work) + ipc_handle = reduce_tensor(self.tensor) + gpu_uuid = get_physical_gpu_id(0) + + torch.cuda.synchronize() + + self.ipc_handle_dict = { + "ipc_handle": ipc_handle, + "gpu_uuid": gpu_uuid, + "shape": tensor_shape, + "dtype": tensor_dtype, + } + + def get_ipc_handle_dict(self) -> dict: + """Return IPC handle dict. 
Tensor stays alive in this actor.""" + return self.ipc_handle_dict + + +@ray.remote(num_gpus=0.5) +def inference_receive_ipc_tensor( + ipc_handle_dict: dict, + mode: str = "ray", +) -> dict: + """Inference task that receives tensor via IPCWeightTransferEngine.""" + from unittest.mock import MagicMock + + import torch + + from vllm.config.parallel import ParallelConfig + from vllm.config.weight_transfer import WeightTransferConfig + from vllm.distributed.weight_transfer.ipc_engine import ( + IPCWeightTransferEngine, + ) + + # Create engine with mock parallel config + config = WeightTransferConfig(backend="ipc") + parallel_config = MagicMock(spec=ParallelConfig) + parallel_config.rank = 0 + parallel_config.world_size = 1 + parallel_config.data_parallel_rank = 0 + + engine = IPCWeightTransferEngine(config, parallel_config) + + # Initialize the engine (no-op for IPC) + init_info = IPCWeightTransferInitInfo() + engine.init_transfer_engine(init_info) + + # Receive weights with a no-op load_weights that captures the tensor + received_tensors = [] + + def noop_load_weights(weights: list[tuple[str, torch.Tensor]]): + for name, tensor in weights: + # Clone tensor to keep it after engine cleans up + received_tensors.append((name, tensor.clone())) + + # Build update dict and go through parse_update_info (exercises __post_init__) + ipc_handles = [{ipc_handle_dict["gpu_uuid"]: ipc_handle_dict["ipc_handle"]}] + + if mode == "ray": + update_dict: dict = { + "names": ["test.weight"], + "dtype_names": [ipc_handle_dict["dtype"]], + "shapes": [ipc_handle_dict["shape"]], + "ipc_handles": ipc_handles, + } + elif mode == "http": + pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8") + update_dict = { + "names": ["test.weight"], + "dtype_names": [ipc_handle_dict["dtype"]], + "shapes": [ipc_handle_dict["shape"]], + "ipc_handles_pickled": pickled, + } + else: + raise ValueError(f"Unknown mode: {mode}") + + update_info = engine.parse_update_info(update_dict) + 
engine.receive_weights(update_info, noop_load_weights) + torch.cuda.synchronize() + + # Verify we received the tensor + success = False + received_shape = None + received_sum = None + + if len(received_tensors) == 1: + name, tensor = received_tensors[0] + received_shape = list(tensor.shape) + received_sum = tensor.sum().item() + # Check shape matches and values are all 42s (trainer sends 42s) + if received_shape == ipc_handle_dict["shape"]: + expected_sum = 42.0 * torch.tensor(ipc_handle_dict["shape"]).prod().item() + if abs(received_sum - expected_sum) < 0.01: + success = True + + engine.shutdown() + + return { + "success": success, + "received_shape": received_shape, + "received_sum": received_sum, + } + + +@pytest.mark.skipif( + torch.cuda.device_count() < 1, + reason="Need at least 1 GPU to run IPC weight transfer test.", +) +@pytest.mark.parametrize("mode", ["ray", "http"]) +def test_ipc_weight_transfer_between_processes(mode: str): + """Test IPC weight transfer from trainer to inference process using Ray. + + Parametrized over transport modes: + - 'ray': ipc_handles passed directly. + - 'http': ipc_handles pickled + base64-encoded, unpickled via __post_init__. + + IPC requires same-GPU access, so we use a placement group to co-locate + the trainer actor and inference task on the same GPU. 
+ """ + from ray.util.placement_group import placement_group + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + + ray.init(ignore_reinit_error=True) + + # Create a placement group to ensure both processes are on the same GPU + # Use fractional GPUs so both tasks can share the same GPU bundle + pg = placement_group([{"GPU": 1, "CPU": 2}]) + ray.get(pg.ready()) + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_capture_child_tasks=True, + ) + + # Tensor to transfer: 100x100 filled with 42s + tensor_shape = [100, 100] + tensor_dtype = "float32" + + # Create trainer actor that holds the tensor and IPC handle (stays alive) + trainer_actor = TrainerActor.options( # type: ignore[attr-defined] + scheduling_strategy=scheduling_strategy + ).remote(tensor_shape, tensor_dtype) + + # Get IPC handle dict (tensor stays alive in trainer actor) + ipc_handle_dict = ray.get(trainer_actor.get_ipc_handle_dict.remote()) + + # Receive tensor in inference process using IPC handles (on same GPU) + # Trainer actor stays alive during this operation + inference_result = ray.get( + inference_receive_ipc_tensor.options( + scheduling_strategy=scheduling_strategy + ).remote(ipc_handle_dict, mode=mode) + ) + + assert inference_result["success"], ( + f"IPC weight transfer failed (mode={mode}). 
" + f"Received shape: {inference_result['received_shape']}, " + f"Received sum: {inference_result['received_sum']}" + ) + + +def test_ipc_receive_weights_missing_gpu_uuid_raises(): + """Test that receive_weights raises if GPU UUID not found in IPC handles.""" + if torch.cuda.device_count() < 1: + pytest.skip("Need at least 1 GPU for this test") + + config = WeightTransferConfig(backend="ipc") + parallel_config = create_mock_parallel_config() + engine = IPCWeightTransferEngine(config, parallel_config) + + # Create IPC handle with wrong GPU UUID + dummy_tensor = torch.ones(10, 10, device="cuda:0") + ipc_handle = reduce_tensor(dummy_tensor) + wrong_uuid = "wrong-uuid-12345" + ipc_handles = [{wrong_uuid: ipc_handle}] + + update_info = IPCWeightTransferUpdateInfo( + names=["w"], + dtype_names=["float32"], + shapes=[[10, 10]], + ipc_handles=ipc_handles, + ) + + with pytest.raises(ValueError, match="IPC handle not found"): + engine.receive_weights(update_info, lambda x: None) diff --git a/tools/pre_commit/check_forbidden_imports.py b/tools/pre_commit/check_forbidden_imports.py index 009e9bcbc..786610138 100644 --- a/tools/pre_commit/check_forbidden_imports.py +++ b/tools/pre_commit/check_forbidden_imports.py @@ -37,6 +37,8 @@ CHECK_IMPORTS = { "vllm/distributed/device_communicators/all_reduce_utils.py", "vllm/distributed/device_communicators/shm_broadcast.py", "vllm/distributed/device_communicators/shm_object_storage.py", + "vllm/distributed/weight_transfer/ipc_engine.py", + "tests/distributed/test_weight_transfer.py", "vllm/utils/hashing.py", "tests/multimodal/media/test_base.py", "tests/tokenizers_/test_hf.py", diff --git a/vllm/config/weight_transfer.py b/vllm/config/weight_transfer.py index 855b0d915..1da1f96cb 100644 --- a/vllm/config/weight_transfer.py +++ b/vllm/config/weight_transfer.py @@ -9,5 +9,5 @@ from vllm.config.utils import config class WeightTransferConfig: """Configuration for weight transfer during RL training.""" - backend: Literal["nccl"] = "nccl" + 
backend: Literal["nccl", "ipc"] = "nccl" """The backend to use for weight transfer.""" diff --git a/vllm/distributed/weight_transfer/base.py b/vllm/distributed/weight_transfer/base.py index b87f190fc..788dcef12 100644 --- a/vllm/distributed/weight_transfer/base.py +++ b/vllm/distributed/weight_transfer/base.py @@ -3,7 +3,7 @@ """Base class for weight transfer engines.""" from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Iterator from dataclasses import KW_ONLY, dataclass, field from typing import Any, Generic, TypeVar @@ -156,3 +156,30 @@ class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]): This should be called when the worker is shutting down. """ raise NotImplementedError + + @staticmethod + @abstractmethod + def trainer_send_weights( + iterator: Iterator[tuple[str, torch.Tensor]], + trainer_args: dict[str, Any] | Any, + ) -> None: + """ + Send weights from trainer to inference workers. + + This is a static method that can be called from the trainer process + to send weights to all inference workers. + + Args: + iterator: Iterator of model parameters. Returns (name, tensor) tuples. + The tensors should be on the appropriate device for the backend. + trainer_args: Dictionary containing backend-specific arguments needed + to send weights. The structure depends on the backend: + - NCCL: Contains 'group', 'src', 'packed', etc. + - IPC: Contains 'mode' ('http' or 'ray'), + 'llm_handle' (for Ray), 'url' (for HTTP), etc. 
+ + Example: + >>> param_iter = ((n, p) for n, p in model.named_parameters()) + >>> engine.trainer_send_weights(param_iter, trainer_args) + """ + raise NotImplementedError diff --git a/vllm/distributed/weight_transfer/factory.py b/vllm/distributed/weight_transfer/factory.py index 7235e30d1..f8e9c864f 100644 --- a/vllm/distributed/weight_transfer/factory.py +++ b/vllm/distributed/weight_transfer/factory.py @@ -114,3 +114,9 @@ WeightTransferEngineFactory.register_engine( "vllm.distributed.weight_transfer.nccl_engine", "NCCLWeightTransferEngine", ) + +WeightTransferEngineFactory.register_engine( + "ipc", + "vllm.distributed.weight_transfer.ipc_engine", + "IPCWeightTransferEngine", +) diff --git a/vllm/distributed/weight_transfer/ipc_engine.py b/vllm/distributed/weight_transfer/ipc_engine.py new file mode 100644 index 000000000..2edbec625 --- /dev/null +++ b/vllm/distributed/weight_transfer/ipc_engine.py @@ -0,0 +1,291 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""IPC-based weight transfer engine using CUDA IPC for communication.""" + +import base64 +import pickle +from collections.abc import Callable, Iterator +from dataclasses import asdict, dataclass +from typing import Any + +import requests +import torch +from torch.multiprocessing.reductions import reduce_tensor + +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferEngine, + WeightTransferInitInfo, + WeightTransferUpdateInfo, +) + + +@dataclass +class IPCTrainerSendWeightsArgs: + """Arguments for IPC trainer_send_weights method.""" + + mode: str + """Transport mode: 'http' or 'ray'.""" + llm_handle: Any = None + """Ray ObjectRef to LLM handle (required for 'ray' mode).""" + url: str | None = None + """Base URL for HTTP endpoint (required for 'http' mode).""" + + def __post_init__(self): + """Validate that required 
arguments are provided for the selected mode.""" + if self.mode == "ray" and self.llm_handle is None: + raise ValueError("llm_handle is required for 'ray' mode") + if self.mode == "http" and self.url is None: + raise ValueError("url is required for 'http' mode") + if self.mode not in ("ray", "http"): + raise ValueError(f"mode must be 'ray' or 'http', got {self.mode}") + + +@dataclass +class IPCWeightTransferInitInfo(WeightTransferInitInfo): + """Initialization info for IPC weight transfer backend. No init needed for IPC.""" + + pass + + +@dataclass +class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo): + """Update info for IPC weight transfer backend. + + Accepts IPC handles either directly via ``ipc_handles`` (Ray transport) + or as a base64-encoded pickle via ``ipc_handles_pickled`` (HTTP transport). + Exactly one of the two must be provided; if ``ipc_handles_pickled`` is set + it is unpickled into ``ipc_handles`` during ``__post_init__``. + """ + + names: list[str] + dtype_names: list[str] + shapes: list[list[int]] + ipc_handles: list[dict[str, tuple[Callable, tuple]]] | None = None + """IPC handles mapping physical GPU UUID to (func, args) tuple. 
+ Each handle is a dictionary mapping GPU UUID strings to IPC handle tuples.""" + ipc_handles_pickled: str | None = None + """Base64-encoded pickled IPC handles, used for HTTP transport.""" + + def __post_init__(self): + if self.ipc_handles_pickled is not None: + if self.ipc_handles is not None: + raise ValueError( + "Cannot specify both `ipc_handles` and `ipc_handles_pickled`" + ) + self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled)) + self.ipc_handles_pickled = None + + if self.ipc_handles is None: + raise ValueError( + "Either `ipc_handles` or `ipc_handles_pickled` must be provided" + ) + + num_params = len(self.names) + if len(self.dtype_names) != num_params: + raise ValueError( + f"`dtype_names` should be of the same size as `names`: " + f"got {len(self.dtype_names)} and {len(self.names)}" + ) + if len(self.shapes) != num_params: + raise ValueError( + f"`shapes` should be of the same size as `names`: " + f"got {len(self.shapes)} and {len(self.names)}" + ) + if len(self.ipc_handles) != num_params: + raise ValueError( + f"`ipc_handles` should be of the same size as `names`: " + f"got {len(self.ipc_handles)} and {len(self.names)}" + ) + + +class IPCWeightTransferEngine( + WeightTransferEngine[IPCWeightTransferInitInfo, IPCWeightTransferUpdateInfo] +): + """ + Weight transfer engine using CUDA IPC for communication between trainer and workers. + + This implementation uses CUDA IPC to transfer weights from the trainer (rank 0) + to all inference workers in a process group. IPC handles are used to share + memory between processes on the same node. + """ + + # Define backend-specific dataclass types + init_info_cls = IPCWeightTransferInitInfo + update_info_cls = IPCWeightTransferUpdateInfo + + def __init__( + self, config: WeightTransferConfig, parallel_config: ParallelConfig + ) -> None: + """ + Initialize the IPC weight transfer engine. 
+
+        Args:
+            config: The configuration for the weight transfer engine
+            parallel_config: The configuration for the parallel setup
+        """
+        super().__init__(config, parallel_config)
+
+    def init_transfer_engine(self, init_info: IPCWeightTransferInitInfo) -> None:
+        """
+        Initialize the weight transfer mechanism.
+        This is called once at the beginning of training.
+        No initialization needed for IPC backend.
+
+        Args:
+            init_info: IPC initialization info (empty)
+        """
+        pass
+
+    def receive_weights(
+        self,
+        update_info: IPCWeightTransferUpdateInfo,
+        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
+    ) -> None:
+        """
+        Receive weights from the trainer via CUDA IPC handles.
+
+        Args:
+            update_info: IPC update info containing parameter names, dtypes, shapes,
+                and IPC handles. Each IPC handle is a mapping between physical
+                GPU UUID and the IPC handle tuple (func, args).
+            load_weights: Callable that loads weights into the model. Called
+                once with the full list of IPC-rebuilt weights.
+        """
+        assert update_info.ipc_handles is not None
+        weights = []
+        for name, _dtype_name, _shape, ipc_handle in zip(
+            update_info.names,
+            update_info.dtype_names,
+            update_info.shapes,
+            update_info.ipc_handles,
+        ):
+            device_index = torch.cuda.current_device()
+            props = torch.cuda.get_device_properties(device_index)
+            physical_gpu_id = str(props.uuid)
+
+            if physical_gpu_id not in ipc_handle:
+                raise ValueError(
+                    f"IPC handle not found for GPU UUID {physical_gpu_id}. "
+                    f"Available UUIDs: {list(ipc_handle.keys())}"
+                )
+
+            handle = ipc_handle[physical_gpu_id]
+
+            func, args = handle
+            list_args = list(args)  # type: ignore
+            # Index 6 is the device_index parameter in torch's
+            # IPC handle tuple (rebuild_cuda_tensor). Update it
+            # to the current device since the logical index can
+            # differ between sender and receiver.
+ list_args[6] = device_index + weight = func(*list_args) # type: ignore + weights.append((name, weight)) + + load_weights(weights) + + def shutdown(self) -> None: + """ + Shutdown the weight transfer engine. + """ + pass + + @staticmethod + def trainer_send_weights( + iterator: Iterator[tuple[str, torch.Tensor]], + trainer_args: dict[str, Any] | IPCTrainerSendWeightsArgs, + ) -> None: + """ + Send weights from trainer to inference workers via CUDA IPC. + + Supports two modes: + - 'ray': Sends weights via Ray RPC to a Ray-based LLM handle + - 'http': Sends weights via HTTP POST to a vLLM HTTP server + + Args: + iterator: Iterator of model parameters. Returns (name, tensor) tuples. + Tensors should be on the same GPU as the inference workers. + trainer_args: Dictionary containing IPC-specific arguments. + Should contain keys from IPCTrainerSendWeightsArgs: + - mode: 'ray' or 'http' + - llm_handle: Ray ObjectRef (for 'ray' mode) + - url: Base URL string (for 'http' mode) + + Example (Ray mode): + >>> from vllm.distributed.weight_transfer.ipc_engine import ( + ... IPCWeightTransferEngine, + ... IPCTrainerSendWeightsArgs, + ... ) + >>> param_iter = ((n, p) for n, p in model.named_parameters()) + >>> args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle) + >>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args)) + + Example (HTTP mode): + >>> args = IPCTrainerSendWeightsArgs( + ... mode="http", url="http://localhost:8000" + ... 
) + >>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args)) + """ + # Parse trainer args - accept either dict or dataclass instance + if isinstance(trainer_args, dict): + args = IPCTrainerSendWeightsArgs(**trainer_args) + else: + args = trainer_args + + # Get physical GPU UUID + device_index = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device_index) + gpu_uuid = str(props.uuid) + + # Collect weight metadata and create IPC handles + names = [] + dtype_names = [] + shapes = [] + ipc_handles = [] + + for name, tensor in iterator: + names.append(name) + dtype_names.append(str(tensor.dtype).split(".")[-1]) + shapes.append(list(tensor.shape)) + + # Create IPC handle for this weight tensor + # The tensor must remain in memory for IPC to work + weight = tensor.detach().contiguous() + ipc_handle = reduce_tensor(weight) + ipc_handles.append({gpu_uuid: ipc_handle}) + + # Send weights based on mode + if args.mode == "ray": + # Ray mode: send via Ray RPC + import ray + + update_info = asdict( + IPCWeightTransferUpdateInfo( + names=names, + dtype_names=dtype_names, + shapes=shapes, + ipc_handles=ipc_handles, + ) + ) + ray.get( + args.llm_handle.update_weights.remote(dict(update_info=update_info)) + ) + elif args.mode == "http": + # HTTP mode: send via HTTP POST with pickled handles + # Pickle and base64 encode IPC handles for HTTP transmission + pickled_handles = base64.b64encode(pickle.dumps(ipc_handles)).decode( + "utf-8" + ) + + url = f"{args.url}/update_weights" + payload = { + "update_info": { + "names": names, + "dtype_names": dtype_names, + "shapes": shapes, + "ipc_handles_pickled": pickled_handles, + } + } + response = requests.post(url, json=payload, timeout=300) + response.raise_for_status() diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py index 5c90198bf..e8a1091b9 100644 --- a/vllm/distributed/weight_transfer/nccl_engine.py +++ 
b/vllm/distributed/weight_transfer/nccl_engine.py @@ -35,6 +35,32 @@ class NCCLWeightTransferInitInfo(WeightTransferInitInfo): world_size: int +@dataclass +class NCCLTrainerSendWeightsArgs: + """Arguments for NCCL trainer_send_weights method.""" + + group: Any + """Process group (PyNcclCommunicator) for NCCL communication.""" + src: int = 0 + """Source rank (default 0, trainer is typically rank 0).""" + post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] | None = None + """Optional function to apply to each (name, tensor) pair before broadcasting. + If None, extracts just the tensor.""" + packed: bool = False + """Whether to use packed tensor broadcasting for efficiency. + When True, multiple tensors are batched together before broadcasting + to reduce NCCL communication overhead.""" + stream: torch.cuda.Stream | None = None + """CUDA stream to use for broadcasting if packed is False. + If packed is True, new streams will be created for each buffer.""" + packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES + """Size in bytes for each packed tensor buffer. + Must match the value used in NCCLWeightTransferUpdateInfo.""" + packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS + """Number of buffers for double/triple buffering during packed transfer. + Must match the value used in NCCLWeightTransferUpdateInfo.""" + + @dataclass class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo): """Update info for NCCL weight transfer backend.""" @@ -47,7 +73,7 @@ class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo): When True, multiple tensors are batched together before broadcasting to reduce NCCL communication overhead.""" packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES - """Size in bytes for each packed tensor buffer. Default is 1GB. + """Size in bytes for each packed tensor buffer. 
Both producer and consumer must use the same value.""" packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS """Number of buffers for double/triple buffering during packed transfer. @@ -186,47 +212,38 @@ class NCCLWeightTransferEngine( @staticmethod def trainer_send_weights( iterator: Iterator[tuple[str, torch.Tensor]], - group: Any, - src: int = 0, - post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] - | None = None, - packed: bool = False, - stream: torch.cuda.Stream | None = None, - packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES, - packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS, + trainer_args: dict[str, Any] | NCCLTrainerSendWeightsArgs, ) -> None: """Broadcast weights from trainer to vLLM workers. Args: iterator: Iterator of model parameters. Returns (name, tensor) tuples - group: Process group (PyNcclCommunicator) - src: Source rank (default 0, trainer is typically rank 0) - post_iter_func: Optional function to apply to each (name, tensor) pair - before broadcasting. If None, extracts just the tensor. - packed: Whether to use packed tensor broadcasting for efficiency. - When True, multiple tensors are batched together before - broadcasting to reduce NCCL communication overhead. - stream: CUDA stream to use for broadcasting if packed is False. - If packed is True, new streams will be created for each buffer. - packed_buffer_size_bytes: Size in bytes for each packed tensor buffer. - Must match the value used in NCCLWeightTransferUpdateInfo. - packed_num_buffers: Number of buffers for double/triple buffering. - Must match the value used in NCCLWeightTransferUpdateInfo. + trainer_args: Dictionary or NCCLTrainerSendWeightsArgs instance containing + NCCL-specific arguments. If a dict, should contain keys from + NCCLTrainerSendWeightsArgs. Example: >>> from vllm.distributed.weight_transfer.nccl_engine import ( ... NCCLWeightTransferEngine, + ... NCCLTrainerSendWeightsArgs, ... 
) >>> param_iter = ((n, p) for n, p in model.named_parameters()) - >>> NCCLWeightTransferEngine.trainer_send_weights( - ... param_iter, group, packed=True - ... ) + >>> args = NCCLTrainerSendWeightsArgs(group=group, packed=True) + >>> NCCLWeightTransferEngine.trainer_send_weights(param_iter, args) """ - if post_iter_func is None: + # Parse trainer args - accept either dict or dataclass instance + if isinstance(trainer_args, dict): + args = NCCLTrainerSendWeightsArgs(**trainer_args) + else: + args = trainer_args + + if args.post_iter_func is None: # Default: extract just the tensor from (name, tensor) tuple post_iter_func = lambda x: x[1] + else: + post_iter_func = args.post_iter_func - if packed: + if args.packed: # Use packed tensor broadcasting for efficiency from vllm.distributed.weight_transfer.packed_tensor import ( packed_broadcast_producer, @@ -234,18 +251,20 @@ class NCCLWeightTransferEngine( packed_broadcast_producer( iterator=iterator, - group=group, - src=src, + group=args.group, + src=args.src, post_iter_func=post_iter_func, - buffer_size_bytes=packed_buffer_size_bytes, - num_buffers=packed_num_buffers, + buffer_size_bytes=args.packed_buffer_size_bytes, + num_buffers=args.packed_num_buffers, ) else: # Use simple one-by-one broadcasting for item in iterator: tensor = post_iter_func(item) - group.broadcast( - tensor, src=src, stream=stream or torch.cuda.current_stream() + args.group.broadcast( + tensor, + src=args.src, + stream=args.stream or torch.cuda.current_stream(), ) @staticmethod