[Feat][RL][2/2] Native Weight Syncing API: IPC (#34171)
Signed-off-by: hao-aaron <ahao@anyscale.com> Signed-off-by: Aaron Hao <ahao@anyscale.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com>
This commit is contained in:
@@ -278,7 +278,8 @@ steps:
|
||||
- popd
|
||||
# NEW rlhf examples
|
||||
- pushd ../examples/offline_inference/new_weight_syncing
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
|
||||
- popd
|
||||
|
||||
|
||||
@@ -103,7 +103,8 @@ steps:
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
||||
# NEW rlhf examples
|
||||
- cd new_weight_syncing
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
|
||||
|
||||
- label: Distributed Tests (8 GPUs)(H100)
|
||||
timeout_in_minutes: 10
|
||||
|
||||
@@ -42,6 +42,7 @@ from vllm.distributed.weight_transfer.base import (
|
||||
WeightTransferUpdateRequest,
|
||||
)
|
||||
from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
NCCLTrainerSendWeightsArgs,
|
||||
NCCLWeightTransferEngine,
|
||||
NCCLWeightTransferInitInfo,
|
||||
NCCLWeightTransferUpdateInfo,
|
||||
@@ -152,11 +153,14 @@ class TrainModel:
|
||||
|
||||
def broadcast_weights(self, packed: bool = True):
|
||||
"""Broadcast weights to the inference engine."""
|
||||
NCCLWeightTransferEngine.trainer_send_weights(
|
||||
iterator=self.model.named_parameters(),
|
||||
trainer_args = NCCLTrainerSendWeightsArgs(
|
||||
group=self.model_update_group,
|
||||
packed=packed,
|
||||
)
|
||||
NCCLWeightTransferEngine.trainer_send_weights(
|
||||
iterator=self.model.named_parameters(),
|
||||
trainer_args=trainer_args,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
|
||||
|
||||
149
examples/offline_inference/new_weight_syncing/rlhf_ipc.py
Normal file
149
examples/offline_inference/new_weight_syncing/rlhf_ipc.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray,
|
||||
with IPC-based weight syncing APIs
|
||||
|
||||
The script colocates the training and inference workloads onto the same GPU using Ray.
|
||||
|
||||
The example performs the following steps:
|
||||
|
||||
* Request a placement group of 1 GPU.
|
||||
* Place the inference model on the above GPU using the placement group.
|
||||
* Place and load the training model on the same GPU using the placement group.
|
||||
* Generate text from a list of prompts using the inference engine.
|
||||
* Update the weights of the training model and broadcast the updated weights
|
||||
to the inference engine by using CUDA IPC handles. Note that
|
||||
for demonstration purposes the inference engine starts from dummy weights and simply receives the trainer's weights unchanged.
|
||||
|
||||
This example assumes a single-node cluster with a single GPU,
|
||||
but can be extended to multiple GPUs.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import ray
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import WeightTransferConfig
|
||||
from vllm.distributed.weight_transfer.ipc_engine import (
|
||||
IPCTrainerSendWeightsArgs,
|
||||
IPCWeightTransferEngine,
|
||||
)
|
||||
|
||||
|
||||
class MyLLM(LLM):
    """LLM subclass that prepares the process environment before vLLM starts.

    The environment is adjusted in ``__init__`` so the settings are in place
    when the parent ``LLM`` constructor runs.
    """

    def __init__(self, *args, **kwargs):
        # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
        # so that vLLM can manage its own device placement within the worker.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        os.environ.update(
            {
                # Each worker uses 0.4 GPU so that two instances fit on the
                # same GPU.
                "VLLM_RAY_PER_WORKER_GPUS": "0.4",
                "VLLM_RAY_BUNDLE_INDICES": "0",
                # Needed for IPC handle serialization.
                "VLLM_ALLOW_INSECURE_SERIALIZATION": "1",
            }
        )
        super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# Load the OPT-125M model onto GPU 0 for the training workload.
|
||||
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
|
||||
|
||||
@ray.remote
class TrainModel:
    """Ray actor holding the training copy of the model on ``cuda:0``.

    It drives the IPC weight-transfer handshake with the inference actor
    whose handle is passed in.
    """

    def __init__(self, llm_handle: ray.actor.ActorHandle):
        # ``Module.to`` returns the module itself, so the chained form stores
        # the same object the two-step form would.
        self.train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(
            "cuda:0"
        )
        self.llm_handle = llm_handle

    def init_weight_transfer(self):
        """Ask the inference actor to initialise its weight-transfer engine."""
        # IPC backend doesn't need initialization info, so an empty dict is sent.
        init_request = {"init_info": {}}
        ray.get(self.llm_handle.init_weight_transfer_engine.remote(init_request))

    def broadcast_weights(self, llm_handle: ray.actor.ActorHandle):
        """Send every named parameter to the inference engine using IPC.

        The handle given here also replaces the one stored at construction.
        """
        self.llm_handle = llm_handle
        IPCWeightTransferEngine.trainer_send_weights(
            iterator=self.train_model.named_parameters(),
            trainer_args=IPCTrainerSendWeightsArgs(
                mode="ray", llm_handle=llm_handle
            ),
        )
|
||||
|
||||
|
||||
def _print_generations(request_outputs):
    """Print each prompt with its first completion between dashed rules."""
    print("-" * 50)
    for request_output in request_outputs:
        prompt = request_output.prompt
        generated_text = request_output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)


ray.init()

# One bundle with a single GPU; both actors below are scheduled into it.
pg_colocate = placement_group([{"GPU": 1, "CPU": 0}])
ray.get(pg_colocate.ready())

# The inference actor itself requests no GPU from Ray (num_gpus=0); the
# per-worker GPU share is configured through the environment in MyLLM.
llm_actor_cls = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg_colocate,
        placement_group_capture_child_tasks=True,
    ),
)(MyLLM)
llm = llm_actor_cls.remote(
    model=MODEL_NAME,
    enforce_eager=True,
    tensor_parallel_size=1,
    distributed_executor_backend="ray",
    gpu_memory_utilization=0.7,
    weight_transfer_config=WeightTransferConfig(backend="ipc"),
    load_format="dummy",
)

# The trainer is placed in the same placement group as the inference actor.
train_model = TrainModel.options(
    num_gpus=0.1,
    num_cpus=0,
    scheduling_strategy=PlacementGroupSchedulingStrategy(
        placement_group=pg_colocate, placement_group_capture_child_tasks=True
    ),
).remote(llm)

# Generate text from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0)

outputs = ray.get(llm.generate.remote(prompts, sampling_params))
_print_generations(outputs)

# NOTE(review): sleep(level=0) / wake_up(tags=["scheduling"]) presumably pause
# and resume scheduling around the weight swap -- confirm against the vLLM API.
ray.get(llm.sleep.remote(level=0))

ray.get(train_model.init_weight_transfer.remote())
# Synchronize the updated weights to the inference engine using batched API.
ray.get(train_model.broadcast_weights.remote(llm))

ray.get(llm.wake_up.remote(tags=["scheduling"]))

# Generate text with the updated model.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
_print_generations(outputs_updated)
|
||||
@@ -36,6 +36,7 @@ from transformers import AutoModelForCausalLM
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import WeightTransferConfig
|
||||
from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
NCCLTrainerSendWeightsArgs,
|
||||
NCCLWeightTransferEngine,
|
||||
)
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
@@ -90,11 +91,14 @@ class TrainModel:
|
||||
|
||||
def broadcast_weights(self, packed: bool = True):
|
||||
"""Broadcast weights to the inference engine."""
|
||||
NCCLWeightTransferEngine.trainer_send_weights(
|
||||
iterator=self.model.named_parameters(),
|
||||
trainer_args = NCCLTrainerSendWeightsArgs(
|
||||
group=self.model_update_group,
|
||||
packed=packed,
|
||||
)
|
||||
NCCLWeightTransferEngine.trainer_send_weights(
|
||||
iterator=self.model.named_parameters(),
|
||||
trainer_args=trainer_args,
|
||||
)
|
||||
|
||||
|
||||
# Initialize Ray and set the visible devices. The vLLM engine will
|
||||
@@ -156,6 +160,8 @@ for output in outputs:
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
|
||||
ray.get(llm.sleep.remote(level=0))
|
||||
|
||||
# Set up the communication channel between the training process and the
|
||||
# inference engine.
|
||||
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
|
||||
@@ -197,6 +203,8 @@ inference_handle = llm.update_weights.remote(
|
||||
train_handle = train_model.broadcast_weights.remote(packed=True)
|
||||
ray.get([train_handle, inference_handle])
|
||||
|
||||
ray.get(llm.wake_up.remote(tags=["scheduling"]))
|
||||
|
||||
# Generate text with the updated model. The output is expected to be normal
|
||||
# because the weights are updated.
|
||||
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||
181
examples/online_serving/new_weight_syncing/rlhf_http_ipc.py
Normal file
181
examples/online_serving/new_weight_syncing/rlhf_http_ipc.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
|
||||
via HTTP API, with IPC-based weight syncing APIs.
|
||||
|
||||
Unlike rlhf_nccl.py which uses NCCL and can use separate GPUs, this script
|
||||
uses CUDA IPC which requires the training model and vLLM server to be on the
|
||||
same GPU. Memory must be carefully managed to fit both models.
|
||||
|
||||
Unlike rlhf.py which creates a vLLM instance programmatically, this script
|
||||
assumes you have already started a vLLM server using `vllm serve`. It uses:
|
||||
- OpenAI-compatible API for inference requests
|
||||
- HTTP endpoints for weight transfer control plane
|
||||
- CUDA IPC for actual weight data transfer
|
||||
|
||||
Prerequisites:
|
||||
Start a vLLM server with weight transfer enabled and reduced GPU memory
|
||||
utilization to leave room for the training model:
|
||||
|
||||
$ VLLM_SERVER_DEV_MODE=1 VLLM_ALLOW_INSECURE_SERIALIZATION=1 \
|
||||
vllm serve facebook/opt-125m --enforce-eager \
|
||||
--weight-transfer-config '{"backend": "ipc"}' \
|
||||
--load-format dummy \
|
||||
--gpu-memory-utilization 0.5
|
||||
|
||||
Then run this script:
|
||||
|
||||
$ python rlhf_http_ipc.py
|
||||
|
||||
The example performs the following steps:
|
||||
|
||||
* Load the training model on GPU 0 (same GPU as the vLLM server).
|
||||
* Generate text using the vLLM server via OpenAI-compatible API. The output
|
||||
is expected to be nonsense because the server is initialized with dummy weights.
|
||||
* Initialize weight transfer via HTTP endpoint (no-op for IPC).
|
||||
* Broadcast the real weights from the training model to the vLLM server
|
||||
using CUDA IPC handles.
|
||||
* Generate text again to show normal output after the weight update.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from openai import OpenAI
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm.distributed.weight_transfer.ipc_engine import (
|
||||
IPCTrainerSendWeightsArgs,
|
||||
IPCWeightTransferEngine,
|
||||
)
|
||||
|
||||
BASE_URL = "http://localhost:8000"
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
|
||||
# Enable insecure serialization for IPC handle serialization
|
||||
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
|
||||
|
||||
|
||||
def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
    """Return one completion text per prompt, in prompt order.

    Each prompt is sent as its own request through ``client.completions``
    with ``max_tokens=32`` and ``temperature=0``; only the first choice of
    each response is kept.
    """

    def _complete(prompt: str) -> str:
        response = client.completions.create(
            model=model,
            prompt=prompt,
            max_tokens=32,
            temperature=0,
        )
        return response.choices[0].text

    return [_complete(prompt) for prompt in prompts]
|
||||
|
||||
|
||||
def init_weight_transfer_engine(base_url: str) -> None:
    """POST an empty ``init_info`` to ``{base_url}/init_weight_transfer_engine``.

    The IPC backend needs no initialization info, hence the empty payload.
    Raises ``requests.HTTPError`` on a non-success status.
    """
    requests.post(
        f"{base_url}/init_weight_transfer_engine",
        json={"init_info": {}},
        timeout=60,
    ).raise_for_status()
|
||||
|
||||
|
||||
def pause_generation(base_url: str) -> None:
    """POST to ``{base_url}/pause``; raises ``requests.HTTPError`` on failure."""
    requests.post(f"{base_url}/pause", timeout=60).raise_for_status()
|
||||
|
||||
|
||||
def resume_generation(base_url: str) -> None:
    """POST to ``{base_url}/resume``; raises ``requests.HTTPError`` on failure."""
    requests.post(f"{base_url}/resume", timeout=60).raise_for_status()
|
||||
|
||||
|
||||
def get_world_size(base_url: str) -> int:
    """Fetch ``world_size`` from the server's ``/get_world_size`` endpoint.

    Raises ``requests.HTTPError`` on a non-success status and ``KeyError`` if
    the JSON body has no ``"world_size"`` entry.
    """
    reply = requests.get(f"{base_url}/get_world_size", timeout=10)
    reply.raise_for_status()
    body = reply.json()
    return body["world_size"]
|
||||
|
||||
|
||||
def main():
    """Run the HTTP + CUDA IPC weight-sync demo against a running vLLM server.

    Steps: load the training model on ``cuda:0``, sample completions from the
    server, initialise weight transfer, pause generation, send the training
    model's parameters through ``IPCWeightTransferEngine`` in HTTP mode,
    resume generation, and sample again.

    Generation is resumed in a ``finally`` block so that a failed transfer
    does not leave the server paused.
    """

    def _report(title: str, texts: list[str]) -> None:
        # Print a titled banner followed by one prompt/completion pair per rule.
        print("-" * 50)
        print(title)
        print("-" * 50)
        for prompt, generated_text in zip(prompts, texts):
            print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
            print("-" * 50)

    # IPC requires the training model to be on the same GPU as the vLLM server
    # The server should be started on GPU 0 with reduced memory utilization
    device = "cuda:0"
    torch.cuda.set_device(device)

    # Load the training model on the same GPU as the server
    print(f"Loading training model: {MODEL_NAME} on {device}")
    print(
        "Note: Ensure the vLLM server was started with --gpu-memory-utilization 0.5 "
        "or lower to leave room for the training model."
    )
    # Use bfloat16 to reduce memory footprint
    train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
    train_model.to(device)
    # No training step is run in this demo, so put the module in eval mode.
    # eval() changes layer behaviour (e.g. dropout); it does not by itself
    # reduce memory use.
    train_model.eval()

    # Create OpenAI client pointing to the vLLM server
    client = OpenAI(
        base_url=f"{BASE_URL}/v1",
        api_key="EMPTY",  # vLLM doesn't require an API key by default
    )

    # Test prompts
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Generate text before weight update. The output is expected to be nonsense
    # because the server is initialized with dummy weights.
    outputs = generate_completions(client, MODEL_NAME, prompts)
    _report("Generating text BEFORE weight update (expect nonsense):", outputs)

    print("Initializing weight transfer (IPC backend)...")

    # Initialize weight transfer on vLLM server (no-op for IPC, but still required)
    init_weight_transfer_engine(BASE_URL)

    # Pause generation before weight sync
    pause_generation(BASE_URL)
    try:
        # Broadcast weights via IPC handles using HTTP mode
        print("Broadcasting weights via CUDA IPC (HTTP)...")
        trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
        IPCWeightTransferEngine.trainer_send_weights(
            iterator=train_model.named_parameters(),
            trainer_args=trainer_args,
        )
    finally:
        # Resume generation after weight sync, even if the transfer failed,
        # so the server is never left paused.
        resume_generation(BASE_URL)

    # Generate text after weight update. The output is expected to be normal
    # because the real weights are now loaded.
    outputs_updated = generate_completions(client, MODEL_NAME, prompts)
    _report("Generating text AFTER weight update:", outputs_updated)

    # Note: The training model and IPC handles remain in memory.
    # In a real RLHF training loop, you would update the training model
    # and create new IPC handles for each weight update.


if __name__ == "__main__":
    main()
|
||||
@@ -39,6 +39,7 @@ from openai import OpenAI
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
NCCLTrainerSendWeightsArgs,
|
||||
NCCLWeightTransferEngine,
|
||||
)
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
@@ -214,11 +215,14 @@ def main():
|
||||
|
||||
# Broadcast all weights from trainer to vLLM workers
|
||||
print("Broadcasting weights via NCCL...")
|
||||
NCCLWeightTransferEngine.trainer_send_weights(
|
||||
iterator=train_model.named_parameters(),
|
||||
trainer_args = NCCLTrainerSendWeightsArgs(
|
||||
group=model_update_group,
|
||||
packed=True,
|
||||
)
|
||||
NCCLWeightTransferEngine.trainer_send_weights(
|
||||
iterator=train_model.named_parameters(),
|
||||
trainer_args=trainer_args,
|
||||
)
|
||||
|
||||
# Wait for update_weights to complete
|
||||
update_thread.join()
|
||||
@@ -3,18 +3,26 @@
|
||||
"""Tests for weight transfer engine backends.
|
||||
|
||||
Unit tests for engine classes (parsing, validation, registry).
|
||||
Integration test for NCCL weight transfer between processes using Ray.
|
||||
Integration tests for NCCL and IPC weight transfer between processes using Ray.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import pickle
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import torch
|
||||
from torch.multiprocessing.reductions import reduce_tensor
|
||||
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.weight_transfer import WeightTransferConfig
|
||||
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
|
||||
from vllm.distributed.weight_transfer.ipc_engine import (
|
||||
IPCWeightTransferEngine,
|
||||
IPCWeightTransferInitInfo,
|
||||
IPCWeightTransferUpdateInfo,
|
||||
)
|
||||
from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
NCCLWeightTransferEngine,
|
||||
NCCLWeightTransferInitInfo,
|
||||
@@ -155,9 +163,29 @@ class TestEngineRegistry:
|
||||
engine = WeightTransferEngineFactory.create_engine(config, parallel_config)
|
||||
assert isinstance(engine, NCCLWeightTransferEngine)
|
||||
|
||||
def test_create_engine_ipc(self):
    """The factory must build an IPCWeightTransferEngine for backend="ipc"."""
    engine = WeightTransferEngineFactory.create_engine(
        WeightTransferConfig(backend="ipc"),
        create_mock_parallel_config(),
    )
    assert isinstance(engine, IPCWeightTransferEngine)
|
||||
|
||||
def test_create_engine_invalid_backend(self):
    """Test that an invalid backend is rejected at both layers.

    First the config constructor must refuse the value; then, with the
    attribute forced to an invalid value, the factory must raise
    ``ValueError``.
    """
    # Pydantic validates Literal types at construction, so we can't create
    # a config with an invalid backend directly. The constructor call is
    # therefore only made inside ``pytest.raises``; a bare call here would
    # abort the test before any assertion ran.
    from pydantic import ValidationError

    # Test that Pydantic prevents invalid backend at construction
    with pytest.raises(ValidationError):
        WeightTransferConfig(backend="invalid")

    # Test factory error by creating a config with valid backend but
    # then manually modifying the backend attribute (bypassing validation)
    config = WeightTransferConfig(backend="nccl")
    # Use object.__setattr__ to bypass Pydantic validation
    object.__setattr__(config, "backend", "invalid")
    parallel_config = create_mock_parallel_config()
    with pytest.raises(ValueError, match="Invalid weight transfer backend"):
        WeightTransferEngineFactory.create_engine(config, parallel_config)
|
||||
@@ -344,3 +372,426 @@ def test_nccl_weight_transfer_between_processes():
|
||||
f"Received shape: {result['received_shape']}, "
|
||||
f"Received sum: {result['received_sum']}"
|
||||
)
|
||||
|
||||
|
||||
# --- Unit Tests: IPCWeightTransferUpdateInfo Validation ---
|
||||
|
||||
|
||||
class TestIPCWeightTransferUpdateInfoValidation:
    """Validation checks for the IPCWeightTransferUpdateInfo dataclass."""

    @staticmethod
    def _cuda_ipc_handles(count: int = 1):
        """Skip without a GPU; otherwise build ``count`` handle dicts.

        Returns the source tensor together with the handle list so the caller
        can keep the tensor referenced for the duration of the test. Every
        dict maps the GPU UUID string to the same ``reduce_tensor`` result.
        """
        if torch.cuda.device_count() < 1:
            pytest.skip("Need at least 1 GPU for this test")

        source = torch.ones(10, 10, device="cuda:0")
        handle = reduce_tensor(source)
        uuid = str(torch.cuda.get_device_properties(0).uuid)
        return source, [{uuid: handle} for _ in range(count)]

    def test_valid_update_info(self):
        """A consistent single-weight description is accepted as given."""
        _source, ipc_handles = self._cuda_ipc_handles()

        info = IPCWeightTransferUpdateInfo(
            names=["layer.weight"],
            dtype_names=["float32"],
            shapes=[[10, 10]],
            ipc_handles=ipc_handles,
        )
        assert info.names == ["layer.weight"]
        assert info.dtype_names == ["float32"]
        assert info.shapes == [[10, 10]]
        assert len(info.ipc_handles) == 1

    def test_mismatched_dtype_names_raises(self):
        """Fewer dtype_names than names must raise ValueError."""
        _source, ipc_handles = self._cuda_ipc_handles(2)

        with pytest.raises(ValueError, match="dtype_names"):
            IPCWeightTransferUpdateInfo(
                names=["layer.weight", "layer.bias"],
                dtype_names=["float32"],  # Only one dtype
                shapes=[[10, 10], [10]],
                ipc_handles=ipc_handles,
            )

    def test_mismatched_shapes_raises(self):
        """Fewer shapes than names must raise ValueError."""
        _source, ipc_handles = self._cuda_ipc_handles(2)

        with pytest.raises(ValueError, match="shapes"):
            IPCWeightTransferUpdateInfo(
                names=["layer.weight", "layer.bias"],
                dtype_names=["float32", "float32"],
                shapes=[[10, 10]],  # Only one shape
                ipc_handles=ipc_handles,
            )

    def test_mismatched_ipc_handles_raises(self):
        """Fewer ipc_handles than names must raise ValueError."""
        _source, ipc_handles = self._cuda_ipc_handles()  # Only one handle

        with pytest.raises(ValueError, match="ipc_handles"):
            IPCWeightTransferUpdateInfo(
                names=["layer.weight", "layer.bias"],
                dtype_names=["float32", "float32"],
                shapes=[[10, 10], [10]],
                ipc_handles=ipc_handles,
            )

    def test_valid_update_info_from_pickled(self):
        """Base64-pickled handles are decoded into ``ipc_handles``."""
        _source, ipc_handles = self._cuda_ipc_handles()
        pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")

        info = IPCWeightTransferUpdateInfo(
            names=["layer.weight"],
            dtype_names=["float32"],
            shapes=[[10, 10]],
            ipc_handles_pickled=pickled,
        )
        assert info.ipc_handles == ipc_handles
        assert info.ipc_handles_pickled is None

    def test_both_handles_and_pickled_raises(self):
        """Supplying ipc_handles and ipc_handles_pickled together must raise."""
        _source, ipc_handles = self._cuda_ipc_handles()
        pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")

        with pytest.raises(ValueError, match="Cannot specify both"):
            IPCWeightTransferUpdateInfo(
                names=["layer.weight"],
                dtype_names=["float32"],
                shapes=[[10, 10]],
                ipc_handles=ipc_handles,
                ipc_handles_pickled=pickled,
            )

    def test_neither_handles_nor_pickled_raises(self):
        """Supplying neither ipc_handles nor ipc_handles_pickled must raise."""
        with pytest.raises(ValueError, match="must be provided"):
            IPCWeightTransferUpdateInfo(
                names=["layer.weight"],
                dtype_names=["float32"],
                shapes=[[10, 10]],
            )

    def test_empty_lists_valid(self):
        """All-empty lists describe zero weights and are accepted."""
        info = IPCWeightTransferUpdateInfo(
            names=[],
            dtype_names=[],
            shapes=[],
            ipc_handles=[],
        )
        assert len(info.names) == 0
|
||||
|
||||
|
||||
# --- Unit Tests: IPC Engine Parsing ---
|
||||
|
||||
|
||||
class TestIPCEngineParsing:
    """Checks for IPCWeightTransferEngine.parse_update_info."""

    @staticmethod
    def _engine_and_handles():
        """Skip without a GPU; otherwise build an engine and two handle dicts.

        Returns ``(engine, tensors, gpu_uuid, ipc_handles)``. The tensors are
        returned so the caller can keep them referenced while the handles are
        in use.
        """
        if torch.cuda.device_count() < 1:
            pytest.skip("Need at least 1 GPU for this test")

        engine = IPCWeightTransferEngine(
            WeightTransferConfig(backend="ipc"),
            create_mock_parallel_config(),
        )

        # Create dummy IPC handles
        tensors = (
            torch.ones(100, 100, device="cuda:0"),
            torch.ones(50, device="cuda:0"),
        )
        handles = [reduce_tensor(tensor) for tensor in tensors]
        gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
        ipc_handles = [{gpu_uuid: handle} for handle in handles]
        return engine, tensors, gpu_uuid, ipc_handles

    def test_parse_update_info_valid(self):
        """A dict carrying raw ``ipc_handles`` parses into the dataclass."""
        engine, _tensors, _gpu_uuid, ipc_handles = self._engine_and_handles()

        update_info = engine.parse_update_info(
            {
                "names": ["w1", "w2"],
                "dtype_names": ["float32", "bfloat16"],
                "shapes": [[100, 100], [50]],
                "ipc_handles": ipc_handles,
            }
        )

        assert isinstance(update_info, IPCWeightTransferUpdateInfo)
        assert update_info.names == ["w1", "w2"]
        assert update_info.dtype_names == ["float32", "bfloat16"]
        assert update_info.shapes == [[100, 100], [50]]
        assert len(update_info.ipc_handles) == 2

    def test_parse_update_info_pickled(self):
        """A dict carrying pickled IPC handles (HTTP path) parses correctly."""
        engine, _tensors, gpu_uuid, ipc_handles = self._engine_and_handles()
        pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")

        update_info = engine.parse_update_info(
            {
                "names": ["w1", "w2"],
                "dtype_names": ["float32", "bfloat16"],
                "shapes": [[100, 100], [50]],
                "ipc_handles_pickled": pickled,
            }
        )

        assert isinstance(update_info, IPCWeightTransferUpdateInfo)
        assert update_info.names == ["w1", "w2"]
        assert len(update_info.ipc_handles) == 2
        assert update_info.ipc_handles_pickled is None
        assert gpu_uuid in update_info.ipc_handles[0]
        assert gpu_uuid in update_info.ipc_handles[1]
|
||||
|
||||
|
||||
# --- Integration Test: IPC Weight Transfer Between Ray Tasks ---
|
||||
|
||||
|
||||
def get_physical_gpu_id(device_index: int = 0) -> str:
    """Return the ``uuid`` device property of ``device_index`` as a string."""
    return str(torch.cuda.get_device_properties(device_index).uuid)
|
||||
|
||||
|
||||
@ray.remote(num_gpus=0.5)
class TrainerActor:
    """Ray actor that owns a CUDA tensor and the IPC handle describing it."""

    def __init__(self, tensor_shape: list[int], tensor_dtype: str):
        # The tensor is stored on the actor so it stays alive for as long as
        # the handle is in use (tensor must stay alive for IPC to work).
        # Every element is 42 so the receiver can verify the transfer.
        self.tensor = torch.full(
            tensor_shape,
            42.0,
            dtype=getattr(torch, tensor_dtype),
            device="cuda:0",
        )

        handle = reduce_tensor(self.tensor)
        uuid = get_physical_gpu_id(0)

        torch.cuda.synchronize()

        self.ipc_handle_dict = dict(
            ipc_handle=handle,
            gpu_uuid=uuid,
            shape=tensor_shape,
            dtype=tensor_dtype,
        )

    def get_ipc_handle_dict(self) -> dict:
        """Return the stored handle description; the tensor stays in this actor."""
        return self.ipc_handle_dict
|
||||
|
||||
|
||||
@ray.remote(num_gpus=0.5)
def inference_receive_ipc_tensor(
    ipc_handle_dict: dict,
    mode: str = "ray",
) -> dict:
    """Inference task that receives one tensor via IPCWeightTransferEngine.

    Args:
        ipc_handle_dict: Dict with keys ``"ipc_handle"``, ``"gpu_uuid"``,
            ``"shape"`` and ``"dtype"`` describing the tensor to receive.
        mode: ``"ray"`` passes the handles directly; ``"http"`` passes them
            base64-encoded and pickled under ``"ipc_handles_pickled"``.

    Returns:
        Dict with ``"success"`` (bool), ``"received_shape"`` and
        ``"received_sum"``. The last two are ``None`` unless exactly one
        tensor was received.

    Raises:
        ValueError: If ``mode`` is neither ``"ray"`` nor ``"http"``.
    """
    # Every dependency is imported locally, matching the imports this task
    # already did for its other names, so the body does not rely on the
    # defining module's globals when it runs in a Ray worker.
    import base64
    import pickle
    from unittest.mock import MagicMock

    import torch

    from vllm.config.parallel import ParallelConfig
    from vllm.config.weight_transfer import WeightTransferConfig
    from vllm.distributed.weight_transfer.ipc_engine import (
        IPCWeightTransferEngine,
        IPCWeightTransferInitInfo,
    )

    # Fail fast on a bad mode, before any engine is created.
    if mode not in ("ray", "http"):
        raise ValueError(f"Unknown mode: {mode}")

    # Create engine with mock parallel config
    config = WeightTransferConfig(backend="ipc")
    parallel_config = MagicMock(spec=ParallelConfig)
    parallel_config.rank = 0
    parallel_config.world_size = 1
    parallel_config.data_parallel_rank = 0

    engine = IPCWeightTransferEngine(config, parallel_config)

    # Receive weights with a no-op load_weights that captures the tensor
    received_tensors = []

    def noop_load_weights(weights: list[tuple[str, torch.Tensor]]):
        for name, tensor in weights:
            # Clone tensor to keep it after engine cleans up
            received_tensors.append((name, tensor.clone()))

    try:
        # Initialize the engine (no-op for IPC)
        engine.init_transfer_engine(IPCWeightTransferInitInfo())

        # Build update dict and go through parse_update_info
        # (exercises __post_init__). Only the handle field differs by mode.
        ipc_handles = [{ipc_handle_dict["gpu_uuid"]: ipc_handle_dict["ipc_handle"]}]
        update_dict: dict = {
            "names": ["test.weight"],
            "dtype_names": [ipc_handle_dict["dtype"]],
            "shapes": [ipc_handle_dict["shape"]],
        }
        if mode == "ray":
            update_dict["ipc_handles"] = ipc_handles
        else:
            update_dict["ipc_handles_pickled"] = base64.b64encode(
                pickle.dumps(ipc_handles)
            ).decode("utf-8")

        update_info = engine.parse_update_info(update_dict)
        engine.receive_weights(update_info, noop_load_weights)
        torch.cuda.synchronize()

        # Verify we received the tensor
        success = False
        received_shape = None
        received_sum = None

        if len(received_tensors) == 1:
            _, tensor = received_tensors[0]
            received_shape = list(tensor.shape)
            received_sum = tensor.sum().item()
            # Check shape matches and values are all 42s (trainer sends 42s)
            expected_sum = 42.0 * torch.tensor(ipc_handle_dict["shape"]).prod().item()
            success = (
                received_shape == ipc_handle_dict["shape"]
                and abs(received_sum - expected_sum) < 0.01
            )
    finally:
        # Shut the engine down even if parsing or receiving raised.
        engine.shutdown()

    return {
        "success": success,
        "received_shape": received_shape,
        "received_sum": received_sum,
    }
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    torch.cuda.device_count() < 1,
    reason="Need at least 1 GPU to run IPC weight transfer test.",
)
@pytest.mark.parametrize("mode", ["ray", "http"])
def test_ipc_weight_transfer_between_processes(mode: str):
    """Test IPC weight transfer from trainer to inference process using Ray.

    Parametrized over transport modes:
    - 'ray': ipc_handles passed directly.
    - 'http': ipc_handles pickled + base64-encoded, unpickled via __post_init__.

    IPC requires same-GPU access, so we use a placement group to co-locate
    the trainer actor and inference task on the same GPU.

    The placement group is removed when the test finishes. Without that,
    the Ray session reused via ``ignore_reinit_error=True`` keeps the GPU
    reserved, and the next parametrized case can wait forever for its own
    placement group on a single-GPU machine.
    """
    from ray.util.placement_group import (
        placement_group,
        remove_placement_group,
    )
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

    ray.init(ignore_reinit_error=True)

    # Create a placement group to ensure both processes are on the same GPU
    # Use fractional GPUs so both tasks can share the same GPU bundle
    pg = placement_group([{"GPU": 1, "CPU": 2}])

    try:
        ray.get(pg.ready())

        scheduling_strategy = PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True,
        )

        # Tensor to transfer: 100x100 filled with 42s
        tensor_shape = [100, 100]
        tensor_dtype = "float32"

        # Create trainer actor that holds the tensor and IPC handle (stays alive)
        trainer_actor = TrainerActor.options(  # type: ignore[attr-defined]
            scheduling_strategy=scheduling_strategy
        ).remote(tensor_shape, tensor_dtype)

        # Get IPC handle dict (tensor stays alive in trainer actor)
        ipc_handle_dict = ray.get(trainer_actor.get_ipc_handle_dict.remote())

        # Receive tensor in inference process using IPC handles (on same GPU)
        # Trainer actor stays alive during this operation
        inference_result = ray.get(
            inference_receive_ipc_tensor.options(
                scheduling_strategy=scheduling_strategy
            ).remote(ipc_handle_dict, mode=mode)
        )
    finally:
        # Release the reserved GPU bundle (this also stops the actor that was
        # scheduled into it) so later tests in the same session can use it.
        remove_placement_group(pg)

    assert inference_result["success"], (
        f"IPC weight transfer failed (mode={mode}). "
        f"Received shape: {inference_result['received_shape']}, "
        f"Received sum: {inference_result['received_sum']}"
    )
|
||||
|
||||
|
||||
def test_ipc_receive_weights_missing_gpu_uuid_raises():
    """receive_weights must reject handles that lack the local GPU's UUID."""
    if not torch.cuda.device_count() >= 1:
        pytest.skip("Need at least 1 GPU for this test")

    engine = IPCWeightTransferEngine(
        WeightTransferConfig(backend="ipc"),
        create_mock_parallel_config(),
    )

    # A real IPC handle, but filed under a UUID no GPU will report
    source_tensor = torch.ones(10, 10, device="cuda:0")
    mismatched_handles = [{"wrong-uuid-12345": reduce_tensor(source_tensor)}]

    update_info = IPCWeightTransferUpdateInfo(
        names=["w"],
        dtype_names=["float32"],
        shapes=[[10, 10]],
        ipc_handles=mismatched_handles,
    )

    with pytest.raises(ValueError, match="IPC handle not found"):
        engine.receive_weights(update_info, lambda x: None)
|
||||
|
||||
@@ -37,6 +37,8 @@ CHECK_IMPORTS = {
|
||||
"vllm/distributed/device_communicators/all_reduce_utils.py",
|
||||
"vllm/distributed/device_communicators/shm_broadcast.py",
|
||||
"vllm/distributed/device_communicators/shm_object_storage.py",
|
||||
"vllm/distributed/weight_transfer/ipc_engine.py",
|
||||
"tests/distributed/test_weight_transfer.py",
|
||||
"vllm/utils/hashing.py",
|
||||
"tests/multimodal/media/test_base.py",
|
||||
"tests/tokenizers_/test_hf.py",
|
||||
|
||||
@@ -9,5 +9,5 @@ from vllm.config.utils import config
|
||||
class WeightTransferConfig:
|
||||
"""Configuration for weight transfer during RL training."""
|
||||
|
||||
backend: Literal["nccl"] = "nccl"
|
||||
backend: Literal["nccl", "ipc"] = "nccl"
|
||||
"""The backend to use for weight transfer."""
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"""Base class for weight transfer engines."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import KW_ONLY, dataclass, field
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
@@ -156,3 +156,30 @@ class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
|
||||
This should be called when the worker is shutting down.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def trainer_send_weights(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
trainer_args: dict[str, Any] | Any,
|
||||
) -> None:
|
||||
"""
|
||||
Send weights from trainer to inference workers.
|
||||
|
||||
This is a static method that can be called from the trainer process
|
||||
to send weights to all inference workers.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
|
||||
The tensors should be on the appropriate device for the backend.
|
||||
trainer_args: Dictionary containing backend-specific arguments needed
|
||||
to send weights. The structure depends on the backend:
|
||||
- NCCL: Contains 'group', 'src', 'packed', etc.
|
||||
- IPC: Contains 'mode' ('http' or 'ray'),
|
||||
'llm_handle' (for Ray), 'url' (for HTTP), etc.
|
||||
|
||||
Example:
|
||||
>>> param_iter = ((n, p) for n, p in model.named_parameters())
|
||||
>>> engine.trainer_send_weights(param_iter, trainer_args)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -114,3 +114,9 @@ WeightTransferEngineFactory.register_engine(
|
||||
"vllm.distributed.weight_transfer.nccl_engine",
|
||||
"NCCLWeightTransferEngine",
|
||||
)
|
||||
|
||||
WeightTransferEngineFactory.register_engine(
|
||||
"ipc",
|
||||
"vllm.distributed.weight_transfer.ipc_engine",
|
||||
"IPCWeightTransferEngine",
|
||||
)
|
||||
|
||||
291
vllm/distributed/weight_transfer/ipc_engine.py
Normal file
291
vllm/distributed/weight_transfer/ipc_engine.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""IPC-based weight transfer engine using CUDA IPC for communication."""
|
||||
|
||||
import base64
|
||||
import pickle
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from torch.multiprocessing.reductions import reduce_tensor
|
||||
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.weight_transfer import WeightTransferConfig
|
||||
from vllm.distributed.weight_transfer.base import (
|
||||
WeightTransferEngine,
|
||||
WeightTransferInitInfo,
|
||||
WeightTransferUpdateInfo,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPCTrainerSendWeightsArgs:
|
||||
"""Arguments for IPC trainer_send_weights method."""
|
||||
|
||||
mode: str
|
||||
"""Transport mode: 'http' or 'ray'."""
|
||||
llm_handle: Any = None
|
||||
"""Ray ObjectRef to LLM handle (required for 'ray' mode)."""
|
||||
url: str | None = None
|
||||
"""Base URL for HTTP endpoint (required for 'http' mode)."""
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate that required arguments are provided for the selected mode."""
|
||||
if self.mode == "ray" and self.llm_handle is None:
|
||||
raise ValueError("llm_handle is required for 'ray' mode")
|
||||
if self.mode == "http" and self.url is None:
|
||||
raise ValueError("url is required for 'http' mode")
|
||||
if self.mode not in ("ray", "http"):
|
||||
raise ValueError(f"mode must be 'ray' or 'http', got {self.mode}")
|
||||
|
||||
|
||||
@dataclass
class IPCWeightTransferInitInfo(WeightTransferInitInfo):
    """Init payload for the IPC backend.

    Declares no fields of its own: the IPC backend needs no initialization.
    """
|
||||
|
||||
|
||||
@dataclass
class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo):
    """Update info for IPC weight transfer backend.

    Accepts IPC handles either directly via ``ipc_handles`` (Ray transport)
    or as a base64-encoded pickle via ``ipc_handles_pickled`` (HTTP transport).
    Exactly one of the two must be provided; if ``ipc_handles_pickled`` is set
    it is unpickled into ``ipc_handles`` during ``__post_init__``.

    Raises (from ``__post_init__``):
        ValueError: If both or neither handle field is given, or if
            ``dtype_names``, ``shapes`` or ``ipc_handles`` differs in length
            from ``names``.
    """

    names: list[str]
    dtype_names: list[str]
    shapes: list[list[int]]
    ipc_handles: list[dict[str, tuple[Callable, tuple]]] | None = None
    """IPC handles mapping physical GPU UUID to (func, args) tuple.
    Each handle is a dictionary mapping GPU UUID strings to IPC handle tuples."""
    ipc_handles_pickled: str | None = None
    """Base64-encoded pickled IPC handles, used for HTTP transport."""

    def __post_init__(self):
        """Decode pickled handles if given, then validate field lengths."""
        if self.ipc_handles_pickled is not None:
            if self.ipc_handles is not None:
                raise ValueError(
                    "Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
                )
            # SECURITY: pickle.loads can execute arbitrary code. This payload
            # can arrive over HTTP, so this path is only safe when every
            # client able to reach the endpoint is trusted.
            # TODO(review): consider gating this path behind an explicit
            # opt-in for insecure serialization.
            self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled))
            # The encoded form is no longer needed once decoded
            self.ipc_handles_pickled = None

        if self.ipc_handles is None:
            raise ValueError(
                "Either `ipc_handles` or `ipc_handles_pickled` must be provided"
            )

        # Every per-parameter list must line up one-to-one with `names`
        num_params = len(self.names)
        per_param_fields = (
            ("dtype_names", self.dtype_names),
            ("shapes", self.shapes),
            ("ipc_handles", self.ipc_handles),
        )
        for field_name, values in per_param_fields:
            if len(values) != num_params:
                raise ValueError(
                    f"`{field_name}` should be of the same size as `names`: "
                    f"got {len(values)} and {num_params}"
                )
|
||||
|
||||
|
||||
class IPCWeightTransferEngine(
    WeightTransferEngine[IPCWeightTransferInitInfo, IPCWeightTransferUpdateInfo]
):
    """
    Weight transfer engine using CUDA IPC for communication between trainer and workers.

    The trainer exports one CUDA IPC handle per weight tensor
    (``trainer_send_weights``) and each worker rebuilds the tensors from those
    handles (``receive_weights``). Handles are keyed by physical GPU UUID, so
    a worker only accepts a handle exported from the GPU it is running on.
    """

    # Define backend-specific dataclass types
    init_info_cls = IPCWeightTransferInitInfo
    update_info_cls = IPCWeightTransferUpdateInfo

    def __init__(
        self, config: WeightTransferConfig, parallel_config: ParallelConfig
    ) -> None:
        """
        Initialize the IPC weight transfer engine.

        Args:
            config: The configuration for the weight transfer engine
            parallel_config: The configuration for the parallel setup
        """
        super().__init__(config, parallel_config)

    def init_transfer_engine(self, init_info: IPCWeightTransferInitInfo) -> None:
        """
        Initialize the weight transfer mechanism.

        No initialization is needed for the IPC backend, so this does nothing.

        Args:
            init_info: IPC initialization info (empty)
        """
        pass

    def receive_weights(
        self,
        update_info: IPCWeightTransferUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Receive weights from the trainer via CUDA IPC handles.

        Args:
            update_info: IPC update info containing parameter names, dtypes, shapes,
                and IPC handles. Each IPC handle is a mapping between physical
                GPU UUID and the IPC handle tuple (func, args).
            load_weights: Callable that loads weights into the model. It is
                called once, with the full list of (name, tensor) pairs, after
                every handle has been opened.

        Raises:
            ValueError: If a handle has no entry for the current GPU's UUID.
        """
        assert update_info.ipc_handles is not None

        # The current device does not change while iterating, so resolve its
        # UUID once instead of once per weight.
        device_index = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(device_index)
        physical_gpu_id = str(props.uuid)

        weights = []
        for name, _dtype_name, _shape, ipc_handle in zip(
            update_info.names,
            update_info.dtype_names,
            update_info.shapes,
            update_info.ipc_handles,
        ):
            if physical_gpu_id not in ipc_handle:
                raise ValueError(
                    f"IPC handle not found for GPU UUID {physical_gpu_id}. "
                    f"Available UUIDs: {list(ipc_handle.keys())}"
                )

            func, args = ipc_handle[physical_gpu_id]
            list_args = list(args)  # type: ignore
            # Index 6 is the device_index parameter in torch's
            # IPC handle tuple (rebuild_cuda_tensor). Update it
            # to the current device since the logical index can
            # differ between sender and receiver.
            list_args[6] = device_index
            weight = func(*list_args)  # type: ignore
            weights.append((name, weight))

        load_weights(weights)

    def shutdown(self) -> None:
        """
        Shutdown the weight transfer engine.

        Nothing to release for the IPC backend.
        """
        pass

    @staticmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        trainer_args: dict[str, Any] | IPCTrainerSendWeightsArgs,
    ) -> None:
        """
        Send weights from trainer to inference workers via CUDA IPC.

        Supports two modes:
        - 'ray': calls ``llm_handle.update_weights.remote(...)`` and waits
          for it with ``ray.get``
        - 'http': POSTs to ``<url>/update_weights`` with the handles pickled
          and base64-encoded

        Args:
            iterator: Iterator of model parameters. Returns (name, tensor) tuples.
                Tensors should be on the same GPU as the inference workers.
            trainer_args: Either an IPCTrainerSendWeightsArgs instance or a
                dict with the same keys:
                - mode: 'ray' or 'http'
                - llm_handle: handle exposing ``update_weights`` (for 'ray' mode)
                - url: Base URL string (for 'http' mode); a trailing slash
                  is tolerated

        Raises:
            ValueError: If ``trainer_args`` is a dict that fails
                IPCTrainerSendWeightsArgs validation.
            requests.HTTPError: In 'http' mode, if the server answers with
                an error status.

        Example (Ray mode):
            >>> from vllm.distributed.weight_transfer.ipc_engine import (
            ...     IPCWeightTransferEngine,
            ...     IPCTrainerSendWeightsArgs,
            ... )
            >>> param_iter = ((n, p) for n, p in model.named_parameters())
            >>> args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
            >>> IPCWeightTransferEngine.trainer_send_weights(param_iter, args)

        Example (HTTP mode):
            >>> args = IPCTrainerSendWeightsArgs(
            ...     mode="http", url="http://localhost:8000"
            ... )
            >>> IPCWeightTransferEngine.trainer_send_weights(param_iter, args)
        """
        # Parse trainer args - accept either dict or dataclass instance
        if isinstance(trainer_args, dict):
            args = IPCTrainerSendWeightsArgs(**trainer_args)
        else:
            args = trainer_args

        # Get physical GPU UUID
        device_index = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(device_index)
        gpu_uuid = str(props.uuid)

        # Collect weight metadata and create IPC handles
        names = []
        dtype_names = []
        shapes = []
        ipc_handles = []
        # References to every exported tensor, held until the send below has
        # returned. `.contiguous()` may allocate a fresh tensor (and the
        # iterator may yield temporaries); without a reference here that
        # memory could be released while the receiver still has to open it.
        exported_weights = []

        for name, tensor in iterator:
            names.append(name)
            dtype_names.append(str(tensor.dtype).split(".")[-1])
            shapes.append(list(tensor.shape))

            # Create IPC handle for this weight tensor
            weight = tensor.detach().contiguous()
            exported_weights.append(weight)
            ipc_handles.append({gpu_uuid: reduce_tensor(weight)})

        # Send weights based on mode
        if args.mode == "ray":
            # Ray mode: send via Ray RPC
            import ray

            update_info = asdict(
                IPCWeightTransferUpdateInfo(
                    names=names,
                    dtype_names=dtype_names,
                    shapes=shapes,
                    ipc_handles=ipc_handles,
                )
            )
            ray.get(
                args.llm_handle.update_weights.remote(dict(update_info=update_info))
            )
        elif args.mode == "http":
            # HTTP mode: send via HTTP POST with pickled handles
            # Pickle and base64 encode IPC handles for HTTP transmission
            pickled_handles = base64.b64encode(pickle.dumps(ipc_handles)).decode(
                "utf-8"
            )

            # Strip a trailing slash so the path is not built with "//"
            base_url = str(args.url).rstrip("/")
            url = f"{base_url}/update_weights"
            payload = {
                "update_info": {
                    "names": names,
                    "dtype_names": dtype_names,
                    "shapes": shapes,
                    "ipc_handles_pickled": pickled_handles,
                }
            }
            response = requests.post(url, json=payload, timeout=300)
            response.raise_for_status()
|
||||
@@ -35,6 +35,32 @@ class NCCLWeightTransferInitInfo(WeightTransferInitInfo):
|
||||
world_size: int
|
||||
|
||||
|
||||
@dataclass
class NCCLTrainerSendWeightsArgs:
    """Trainer-side options consumed by the NCCL ``trainer_send_weights``."""

    group: Any
    """Communicator (PyNcclCommunicator) the weights are broadcast over."""
    src: int = 0
    """Rank that broadcasts; defaults to 0, the usual trainer rank."""
    post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] | None = None
    """Optional hook mapping each (name, tensor) pair to the tensor that is
    actually broadcast. When None, the tensor of the pair is used as-is."""
    packed: bool = False
    """Enable packed broadcasting: several tensors are batched into one
    buffer before each broadcast to cut NCCL communication overhead."""
    stream: torch.cuda.Stream | None = None
    """CUDA stream for the broadcasts when packed is False.
    With packed=True, new streams are created for each buffer instead."""
    packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
    """Byte size of each packed buffer.
    Has to equal the value set in NCCLWeightTransferUpdateInfo."""
    packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
    """How many packed buffers to rotate through (double/triple buffering).
    Has to equal the value set in NCCLWeightTransferUpdateInfo."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
|
||||
"""Update info for NCCL weight transfer backend."""
|
||||
@@ -47,7 +73,7 @@ class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
|
||||
When True, multiple tensors are batched together before broadcasting
|
||||
to reduce NCCL communication overhead."""
|
||||
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
|
||||
"""Size in bytes for each packed tensor buffer. Default is 1GB.
|
||||
"""Size in bytes for each packed tensor buffer.
|
||||
Both producer and consumer must use the same value."""
|
||||
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
|
||||
"""Number of buffers for double/triple buffering during packed transfer.
|
||||
@@ -186,47 +212,38 @@ class NCCLWeightTransferEngine(
|
||||
@staticmethod
|
||||
def trainer_send_weights(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
group: Any,
|
||||
src: int = 0,
|
||||
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor]
|
||||
| None = None,
|
||||
packed: bool = False,
|
||||
stream: torch.cuda.Stream | None = None,
|
||||
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
|
||||
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
|
||||
trainer_args: dict[str, Any] | NCCLTrainerSendWeightsArgs,
|
||||
) -> None:
|
||||
"""Broadcast weights from trainer to vLLM workers.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of model parameters. Returns (name, tensor) tuples
|
||||
group: Process group (PyNcclCommunicator)
|
||||
src: Source rank (default 0, trainer is typically rank 0)
|
||||
post_iter_func: Optional function to apply to each (name, tensor) pair
|
||||
before broadcasting. If None, extracts just the tensor.
|
||||
packed: Whether to use packed tensor broadcasting for efficiency.
|
||||
When True, multiple tensors are batched together before
|
||||
broadcasting to reduce NCCL communication overhead.
|
||||
stream: CUDA stream to use for broadcasting if packed is False.
|
||||
If packed is True, new streams will be created for each buffer.
|
||||
packed_buffer_size_bytes: Size in bytes for each packed tensor buffer.
|
||||
Must match the value used in NCCLWeightTransferUpdateInfo.
|
||||
packed_num_buffers: Number of buffers for double/triple buffering.
|
||||
Must match the value used in NCCLWeightTransferUpdateInfo.
|
||||
trainer_args: Dictionary or NCCLTrainerSendWeightsArgs instance containing
|
||||
NCCL-specific arguments. If a dict, should contain keys from
|
||||
NCCLTrainerSendWeightsArgs.
|
||||
|
||||
Example:
|
||||
>>> from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
... NCCLWeightTransferEngine,
|
||||
... NCCLTrainerSendWeightsArgs,
|
||||
... )
|
||||
>>> param_iter = ((n, p) for n, p in model.named_parameters())
|
||||
>>> NCCLWeightTransferEngine.trainer_send_weights(
|
||||
... param_iter, group, packed=True
|
||||
... )
|
||||
>>> args = NCCLTrainerSendWeightsArgs(group=group, packed=True)
|
||||
>>> NCCLWeightTransferEngine.trainer_send_weights(param_iter, args)
|
||||
"""
|
||||
if post_iter_func is None:
|
||||
# Parse trainer args - accept either dict or dataclass instance
|
||||
if isinstance(trainer_args, dict):
|
||||
args = NCCLTrainerSendWeightsArgs(**trainer_args)
|
||||
else:
|
||||
args = trainer_args
|
||||
|
||||
if args.post_iter_func is None:
|
||||
# Default: extract just the tensor from (name, tensor) tuple
|
||||
post_iter_func = lambda x: x[1]
|
||||
else:
|
||||
post_iter_func = args.post_iter_func
|
||||
|
||||
if packed:
|
||||
if args.packed:
|
||||
# Use packed tensor broadcasting for efficiency
|
||||
from vllm.distributed.weight_transfer.packed_tensor import (
|
||||
packed_broadcast_producer,
|
||||
@@ -234,18 +251,20 @@ class NCCLWeightTransferEngine(
|
||||
|
||||
packed_broadcast_producer(
|
||||
iterator=iterator,
|
||||
group=group,
|
||||
src=src,
|
||||
group=args.group,
|
||||
src=args.src,
|
||||
post_iter_func=post_iter_func,
|
||||
buffer_size_bytes=packed_buffer_size_bytes,
|
||||
num_buffers=packed_num_buffers,
|
||||
buffer_size_bytes=args.packed_buffer_size_bytes,
|
||||
num_buffers=args.packed_num_buffers,
|
||||
)
|
||||
else:
|
||||
# Use simple one-by-one broadcasting
|
||||
for item in iterator:
|
||||
tensor = post_iter_func(item)
|
||||
group.broadcast(
|
||||
tensor, src=src, stream=stream or torch.cuda.current_stream()
|
||||
args.group.broadcast(
|
||||
tensor,
|
||||
src=args.src,
|
||||
stream=args.stream or torch.cuda.current_stream(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user