[Feat][RL][2/2] Native Weight Syncing API: IPC (#34171)

Signed-off-by: hao-aaron <ahao@anyscale.com>
Signed-off-by: Aaron Hao <ahao@anyscale.com>
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
This commit is contained in:
Aaron Hao
2026-02-27 12:45:21 -08:00
committed by GitHub
parent 1f3dbd95fd
commit 2ce6f3cf67
14 changed files with 1189 additions and 45 deletions

View File

@@ -278,7 +278,8 @@ steps:
- popd
# NEW rlhf examples
- pushd ../examples/offline_inference/new_weight_syncing
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
- popd

View File

@@ -103,7 +103,8 @@ steps:
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
# NEW rlhf examples
- cd new_weight_syncing
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
- label: Distributed Tests (8 GPUs)(H100)
timeout_in_minutes: 10

View File

@@ -42,6 +42,7 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferUpdateRequest,
)
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
NCCLWeightTransferInitInfo,
NCCLWeightTransferUpdateInfo,
@@ -152,11 +153,14 @@ class TrainModel:
def broadcast_weights(self, packed: bool = True):
"""Broadcast weights to the inference engine."""
NCCLWeightTransferEngine.trainer_send_weights(
iterator=self.model.named_parameters(),
trainer_args = NCCLTrainerSendWeightsArgs(
group=self.model_update_group,
packed=packed,
)
NCCLWeightTransferEngine.trainer_send_weights(
iterator=self.model.named_parameters(),
trainer_args=trainer_args,
)
@torch.inference_mode()
def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:

View File

@@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray,
with IPC-based weight syncing APIs
The script colocates the training and inference workloads onto the same GPU using Ray.
The example performs the following steps:
* Request a placement group of 1 GPU.
* Place the inference model on the above GPU using the placement group.
* Place and load the training model on the same GPU using the placement group.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using CUDA IPC handles. Note that
for demonstration purposes we simply zero out the weights.
This example assumes a single-node cluster with a single GPU,
but can be extended to multiple GPUs.
"""
import os
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.ipc_engine import (
IPCTrainerSendWeightsArgs,
IPCWeightTransferEngine,
)
class MyLLM(LLM):
"""Configure the vLLM worker for Ray placement group execution."""
def __init__(self, *args, **kwargs):
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
# so that vLLM can manage its own device placement within the worker.
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
# Each worker uses 0.4 GPU so that two instances fit on the same GPU.
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0"
# needed for ipc handle serialization
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
super().__init__(*args, **kwargs)
# Load the OPT-125M model onto GPU 0 for the training workload.
MODEL_NAME = "facebook/opt-125m"
@ray.remote
class TrainModel:
def __init__(self, llm_handle: ray.actor.ActorHandle):
self.train_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
)
self.train_model.to("cuda:0")
self.llm_handle = llm_handle
def init_weight_transfer(self):
# IPC backend doesn't need initialization info
ray.get(
self.llm_handle.init_weight_transfer_engine.remote(dict(init_info=dict()))
)
def broadcast_weights(self, llm_handle: ray.actor.ActorHandle):
"""Broadcast weights to the inference engine using IPC."""
self.llm_handle = llm_handle
trainer_args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
IPCWeightTransferEngine.trainer_send_weights(
iterator=self.train_model.named_parameters(),
trainer_args=trainer_args,
)
ray.init()
pg_colocate = placement_group([{"GPU": 1, "CPU": 0}])
ray.get(pg_colocate.ready())
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg_colocate,
placement_group_capture_child_tasks=True,
),
)(MyLLM).remote(
model=MODEL_NAME,
enforce_eager=True,
tensor_parallel_size=1,
distributed_executor_backend="ray",
gpu_memory_utilization=0.7,
weight_transfer_config=WeightTransferConfig(backend="ipc"),
load_format="dummy",
)
train_model = TrainModel.options(
num_gpus=0.1,
num_cpus=0,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg_colocate, placement_group_capture_child_tasks=True
),
).remote(llm)
# Generate text from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
ray.get(llm.sleep.remote(level=0))
ray.get(train_model.init_weight_transfer.remote())
# Synchronize the updated weights to the inference engine using batched API.
ray.get(train_model.broadcast_weights.remote(llm))
ray.get(llm.wake_up.remote(tags=["scheduling"]))
# Generate text with the updated model.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)

View File

@@ -36,6 +36,7 @@ from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port
@@ -90,11 +91,14 @@ class TrainModel:
def broadcast_weights(self, packed: bool = True):
"""Broadcast weights to the inference engine."""
NCCLWeightTransferEngine.trainer_send_weights(
iterator=self.model.named_parameters(),
trainer_args = NCCLTrainerSendWeightsArgs(
group=self.model_update_group,
packed=packed,
)
NCCLWeightTransferEngine.trainer_send_weights(
iterator=self.model.named_parameters(),
trainer_args=trainer_args,
)
# Initialize Ray and set the visible devices. The vLLM engine will
@@ -156,6 +160,8 @@ for output in outputs:
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
ray.get(llm.sleep.remote(level=0))
# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
@@ -197,6 +203,8 @@ inference_handle = llm.update_weights.remote(
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
ray.get(llm.wake_up.remote(tags=["scheduling"]))
# Generate text with the updated model. The output is expected to be normal
# because the weights are updated.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))

View File

@@ -0,0 +1,181 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
via HTTP API, with IPC-based weight syncing APIs.
Unlike rlhf_nccl.py which uses NCCL and can use separate GPUs, this script
uses CUDA IPC which requires the training model and vLLM server to be on the
same GPU. Memory must be carefully managed to fit both models.
Unlike rlhf.py which creates a vLLM instance programmatically, this script
assumes you have already started a vLLM server using `vllm serve`. It uses:
- OpenAI-compatible API for inference requests
- HTTP endpoints for weight transfer control plane
- CUDA IPC for actual weight data transfer
Prerequisites:
Start a vLLM server with weight transfer enabled and reduced GPU memory
utilization to leave room for the training model:
$ VLLM_SERVER_DEV_MODE=1 VLLM_ALLOW_INSECURE_SERIALIZATION=1 \
vllm serve facebook/opt-125m --enforce-eager \
--weight-transfer-config '{"backend": "ipc"}' \
--load-format dummy \
--gpu-memory-utilization 0.5
Then run this script:
$ python rlhf_http_ipc.py
The example performs the following steps:
* Load the training model on GPU 0 (same GPU as the vLLM server).
* Generate text using the vLLM server via OpenAI-compatible API. The output
is expected to be nonsense because the server is initialized with dummy weights.
* Initialize weight transfer via HTTP endpoint (no-op for IPC).
* Broadcast the real weights from the training model to the vLLM server
using CUDA IPC handles.
* Generate text again to show normal output after the weight update.
"""
import os
import requests
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM
from vllm.distributed.weight_transfer.ipc_engine import (
IPCTrainerSendWeightsArgs,
IPCWeightTransferEngine,
)
BASE_URL = "http://localhost:8000"
MODEL_NAME = "facebook/opt-125m"
# Enable insecure serialization for IPC handle serialization
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
"""Generate completions using the OpenAI-compatible API."""
results = []
for prompt in prompts:
response = client.completions.create(
model=model,
prompt=prompt,
max_tokens=32,
temperature=0,
)
results.append(response.choices[0].text)
return results
def init_weight_transfer_engine(base_url: str) -> None:
"""Initialize weight transfer via HTTP endpoint (no-op for IPC)."""
url = f"{base_url}/init_weight_transfer_engine"
payload = {"init_info": dict()}
response = requests.post(url, json=payload, timeout=60)
response.raise_for_status()
def pause_generation(base_url: str) -> None:
"""Pause generation via HTTP endpoint."""
url = f"{base_url}/pause"
response = requests.post(url, timeout=60)
response.raise_for_status()
def resume_generation(base_url: str) -> None:
"""Resume generation via HTTP endpoint."""
url = f"{base_url}/resume"
response = requests.post(url, timeout=60)
response.raise_for_status()
def get_world_size(base_url: str) -> int:
"""Get world size from the vLLM server."""
url = f"{base_url}/get_world_size"
response = requests.get(url, timeout=10)
response.raise_for_status()
return response.json()["world_size"]
def main():
# IPC requires the training model to be on the same GPU as the vLLM server
# The server should be started on GPU 0 with reduced memory utilization
device = "cuda:0"
torch.cuda.set_device(device)
# Load the training model on the same GPU as the server
# Use bfloat16 to reduce memory footprint
print(f"Loading training model: {MODEL_NAME} on {device}")
print(
"Note: Ensure the vLLM server was started with --gpu-memory-utilization 0.5 "
"or lower to leave room for the training model."
)
train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
train_model.to(device)
train_model.eval() # Set to eval mode to save memory
# Create OpenAI client pointing to the vLLM server
client = OpenAI(
base_url=f"{BASE_URL}/v1",
api_key="EMPTY", # vLLM doesn't require an API key by default
)
# Test prompts
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Generate text before weight update. The output is expected to be nonsense
# because the server is initialized with dummy weights.
print("-" * 50)
print("Generating text BEFORE weight update (expect nonsense):")
print("-" * 50)
outputs = generate_completions(client, MODEL_NAME, prompts)
for prompt, generated_text in zip(prompts, outputs):
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
print("Initializing weight transfer (IPC backend)...")
# Initialize weight transfer on vLLM server (no-op for IPC, but still required)
init_weight_transfer_engine(BASE_URL)
# Pause generation before weight sync
pause_generation(BASE_URL)
# Broadcast weights via IPC handles using HTTP mode
print("Broadcasting weights via CUDA IPC (HTTP)...")
trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
IPCWeightTransferEngine.trainer_send_weights(
iterator=train_model.named_parameters(),
trainer_args=trainer_args,
)
# Resume generation after weight sync
resume_generation(BASE_URL)
# Generate text after weight update. The output is expected to be normal
# because the real weights are now loaded.
print("-" * 50)
print("Generating text AFTER weight update:")
print("-" * 50)
outputs_updated = generate_completions(client, MODEL_NAME, prompts)
for prompt, generated_text in zip(prompts, outputs_updated):
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Note: The training model and IPC handles remain in memory.
# In a real RLHF training loop, you would update the training model
# and create new IPC handles for each weight update.
if __name__ == "__main__":
main()

View File

@@ -39,6 +39,7 @@ from openai import OpenAI
from transformers import AutoModelForCausalLM
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port
@@ -214,11 +215,14 @@ def main():
# Broadcast all weights from trainer to vLLM workers
print("Broadcasting weights via NCCL...")
NCCLWeightTransferEngine.trainer_send_weights(
iterator=train_model.named_parameters(),
trainer_args = NCCLTrainerSendWeightsArgs(
group=model_update_group,
packed=True,
)
NCCLWeightTransferEngine.trainer_send_weights(
iterator=train_model.named_parameters(),
trainer_args=trainer_args,
)
# Wait for update_weights to complete
update_thread.join()

View File

@@ -3,18 +3,26 @@
"""Tests for weight transfer engine backends.
Unit tests for engine classes (parsing, validation, registry).
Integration test for NCCL weight transfer between processes using Ray.
Integration tests for NCCL and IPC weight transfer between processes using Ray.
"""
import base64
import pickle
from unittest.mock import MagicMock
import pytest
import ray
import torch
from torch.multiprocessing.reductions import reduce_tensor
from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
from vllm.distributed.weight_transfer.ipc_engine import (
IPCWeightTransferEngine,
IPCWeightTransferInitInfo,
IPCWeightTransferUpdateInfo,
)
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLWeightTransferEngine,
NCCLWeightTransferInitInfo,
@@ -155,9 +163,29 @@ class TestEngineRegistry:
engine = WeightTransferEngineFactory.create_engine(config, parallel_config)
assert isinstance(engine, NCCLWeightTransferEngine)
def test_create_engine_ipc(self):
"""Test factory creates IPC engine."""
config = WeightTransferConfig(backend="ipc")
parallel_config = create_mock_parallel_config()
engine = WeightTransferEngineFactory.create_engine(config, parallel_config)
assert isinstance(engine, IPCWeightTransferEngine)
def test_create_engine_invalid_backend(self):
"""Test factory raises for invalid backend."""
config = WeightTransferConfig(backend="invalid")
# Pydantic validates Literal types at construction, so we can't create
# a config with an invalid backend. Instead, we test by directly
# accessing the registry or using model_construct to bypass validation.
from pydantic import ValidationError
# Test that Pydantic prevents invalid backend at construction
with pytest.raises(ValidationError):
WeightTransferConfig(backend="invalid")
# Test factory error by creating a config with valid backend but
# then manually modifying the backend attribute (bypassing validation)
config = WeightTransferConfig(backend="nccl")
# Use object.__setattr__ to bypass Pydantic validation
object.__setattr__(config, "backend", "invalid")
parallel_config = create_mock_parallel_config()
with pytest.raises(ValueError, match="Invalid weight transfer backend"):
WeightTransferEngineFactory.create_engine(config, parallel_config)
@@ -344,3 +372,426 @@ def test_nccl_weight_transfer_between_processes():
f"Received shape: {result['received_shape']}, "
f"Received sum: {result['received_sum']}"
)
# --- Unit Tests: IPCWeightTransferUpdateInfo Validation ---
class TestIPCWeightTransferUpdateInfoValidation:
"""Test IPCWeightTransferUpdateInfo dataclass validation."""
def test_valid_update_info(self):
"""Test creating valid IPCWeightTransferUpdateInfo."""
if torch.cuda.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
# Create a dummy tensor and IPC handle
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}]
info = IPCWeightTransferUpdateInfo(
names=["layer.weight"],
dtype_names=["float32"],
shapes=[[10, 10]],
ipc_handles=ipc_handles,
)
assert info.names == ["layer.weight"]
assert info.dtype_names == ["float32"]
assert info.shapes == [[10, 10]]
assert len(info.ipc_handles) == 1
def test_mismatched_dtype_names_raises(self):
"""Test that mismatched dtype_names length raises ValueError."""
if torch.cuda.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}, {gpu_uuid: ipc_handle}]
with pytest.raises(ValueError, match="dtype_names"):
IPCWeightTransferUpdateInfo(
names=["layer.weight", "layer.bias"],
dtype_names=["float32"], # Only one dtype
shapes=[[10, 10], [10]],
ipc_handles=ipc_handles,
)
def test_mismatched_shapes_raises(self):
"""Test that mismatched shapes length raises ValueError."""
if torch.cuda.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}, {gpu_uuid: ipc_handle}]
with pytest.raises(ValueError, match="shapes"):
IPCWeightTransferUpdateInfo(
names=["layer.weight", "layer.bias"],
dtype_names=["float32", "float32"],
shapes=[[10, 10]], # Only one shape
ipc_handles=ipc_handles,
)
def test_mismatched_ipc_handles_raises(self):
"""Test that mismatched ipc_handles length raises ValueError."""
if torch.cuda.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}] # Only one handle
with pytest.raises(ValueError, match="ipc_handles"):
IPCWeightTransferUpdateInfo(
names=["layer.weight", "layer.bias"],
dtype_names=["float32", "float32"],
shapes=[[10, 10], [10]],
ipc_handles=ipc_handles,
)
def test_valid_update_info_from_pickled(self):
"""Test creating IPCWeightTransferUpdateInfo from pickled handles."""
if torch.cuda.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}]
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
info = IPCWeightTransferUpdateInfo(
names=["layer.weight"],
dtype_names=["float32"],
shapes=[[10, 10]],
ipc_handles_pickled=pickled,
)
assert info.ipc_handles == ipc_handles
assert info.ipc_handles_pickled is None
def test_both_handles_and_pickled_raises(self):
"""Test that providing both ipc_handles and ipc_handles_pickled raises."""
if torch.cuda.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle}]
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
with pytest.raises(ValueError, match="Cannot specify both"):
IPCWeightTransferUpdateInfo(
names=["layer.weight"],
dtype_names=["float32"],
shapes=[[10, 10]],
ipc_handles=ipc_handles,
ipc_handles_pickled=pickled,
)
def test_neither_handles_nor_pickled_raises(self):
"""Test that providing neither ipc_handles nor ipc_handles_pickled raises."""
with pytest.raises(ValueError, match="must be provided"):
IPCWeightTransferUpdateInfo(
names=["layer.weight"],
dtype_names=["float32"],
shapes=[[10, 10]],
)
def test_empty_lists_valid(self):
"""Test that empty lists are valid."""
info = IPCWeightTransferUpdateInfo(
names=[],
dtype_names=[],
shapes=[],
ipc_handles=[],
)
assert len(info.names) == 0
# --- Unit Tests: IPC Engine Parsing ---
class TestIPCEngineParsing:
"""Test IPCWeightTransferEngine parsing methods."""
def test_parse_update_info_valid(self):
"""Test parsing valid update info dict."""
if torch.cuda.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
config = WeightTransferConfig(backend="ipc")
parallel_config = create_mock_parallel_config()
engine = IPCWeightTransferEngine(config, parallel_config)
# Create dummy IPC handles
dummy_tensor1 = torch.ones(100, 100, device="cuda:0")
dummy_tensor2 = torch.ones(50, device="cuda:0")
ipc_handle1 = reduce_tensor(dummy_tensor1)
ipc_handle2 = reduce_tensor(dummy_tensor2)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle1}, {gpu_uuid: ipc_handle2}]
update_info = engine.parse_update_info(
{
"names": ["w1", "w2"],
"dtype_names": ["float32", "bfloat16"],
"shapes": [[100, 100], [50]],
"ipc_handles": ipc_handles,
}
)
assert isinstance(update_info, IPCWeightTransferUpdateInfo)
assert update_info.names == ["w1", "w2"]
assert update_info.dtype_names == ["float32", "bfloat16"]
assert update_info.shapes == [[100, 100], [50]]
assert len(update_info.ipc_handles) == 2
def test_parse_update_info_pickled(self):
"""Test parsing update info with pickled IPC handles (HTTP path)."""
if torch.cuda.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
config = WeightTransferConfig(backend="ipc")
parallel_config = create_mock_parallel_config()
engine = IPCWeightTransferEngine(config, parallel_config)
dummy_tensor1 = torch.ones(100, 100, device="cuda:0")
dummy_tensor2 = torch.ones(50, device="cuda:0")
ipc_handle1 = reduce_tensor(dummy_tensor1)
ipc_handle2 = reduce_tensor(dummy_tensor2)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
ipc_handles = [{gpu_uuid: ipc_handle1}, {gpu_uuid: ipc_handle2}]
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
update_info = engine.parse_update_info(
{
"names": ["w1", "w2"],
"dtype_names": ["float32", "bfloat16"],
"shapes": [[100, 100], [50]],
"ipc_handles_pickled": pickled,
}
)
assert isinstance(update_info, IPCWeightTransferUpdateInfo)
assert update_info.names == ["w1", "w2"]
assert len(update_info.ipc_handles) == 2
assert update_info.ipc_handles_pickled is None
assert gpu_uuid in update_info.ipc_handles[0]
assert gpu_uuid in update_info.ipc_handles[1]
# --- Integration Test: IPC Weight Transfer Between Ray Tasks ---
def get_physical_gpu_id(device_index: int = 0) -> str:
"""Get physical GPU UUID for a device."""
props = torch.cuda.get_device_properties(device_index)
return str(props.uuid)
@ray.remote(num_gpus=0.5)
class TrainerActor:
"""Trainer actor that creates and holds CUDA IPC handles."""
def __init__(self, tensor_shape: list[int], tensor_dtype: str):
# Create tensor on GPU and keep it alive
dtype = getattr(torch, tensor_dtype)
self.tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda:0")
self.tensor.fill_(42.0) # Fill with 42 to verify correct transfer
# Create IPC handle (tensor must stay alive for IPC to work)
ipc_handle = reduce_tensor(self.tensor)
gpu_uuid = get_physical_gpu_id(0)
torch.cuda.synchronize()
self.ipc_handle_dict = {
"ipc_handle": ipc_handle,
"gpu_uuid": gpu_uuid,
"shape": tensor_shape,
"dtype": tensor_dtype,
}
def get_ipc_handle_dict(self) -> dict:
"""Return IPC handle dict. Tensor stays alive in this actor."""
return self.ipc_handle_dict
@ray.remote(num_gpus=0.5)
def inference_receive_ipc_tensor(
ipc_handle_dict: dict,
mode: str = "ray",
) -> dict:
"""Inference task that receives tensor via IPCWeightTransferEngine."""
from unittest.mock import MagicMock
import torch
from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
from vllm.distributed.weight_transfer.ipc_engine import (
IPCWeightTransferEngine,
)
# Create engine with mock parallel config
config = WeightTransferConfig(backend="ipc")
parallel_config = MagicMock(spec=ParallelConfig)
parallel_config.rank = 0
parallel_config.world_size = 1
parallel_config.data_parallel_rank = 0
engine = IPCWeightTransferEngine(config, parallel_config)
# Initialize the engine (no-op for IPC)
init_info = IPCWeightTransferInitInfo()
engine.init_transfer_engine(init_info)
# Receive weights with a no-op load_weights that captures the tensor
received_tensors = []
def noop_load_weights(weights: list[tuple[str, torch.Tensor]]):
for name, tensor in weights:
# Clone tensor to keep it after engine cleans up
received_tensors.append((name, tensor.clone()))
# Build update dict and go through parse_update_info (exercises __post_init__)
ipc_handles = [{ipc_handle_dict["gpu_uuid"]: ipc_handle_dict["ipc_handle"]}]
if mode == "ray":
update_dict: dict = {
"names": ["test.weight"],
"dtype_names": [ipc_handle_dict["dtype"]],
"shapes": [ipc_handle_dict["shape"]],
"ipc_handles": ipc_handles,
}
elif mode == "http":
pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
update_dict = {
"names": ["test.weight"],
"dtype_names": [ipc_handle_dict["dtype"]],
"shapes": [ipc_handle_dict["shape"]],
"ipc_handles_pickled": pickled,
}
else:
raise ValueError(f"Unknown mode: {mode}")
update_info = engine.parse_update_info(update_dict)
engine.receive_weights(update_info, noop_load_weights)
torch.cuda.synchronize()
# Verify we received the tensor
success = False
received_shape = None
received_sum = None
if len(received_tensors) == 1:
name, tensor = received_tensors[0]
received_shape = list(tensor.shape)
received_sum = tensor.sum().item()
# Check shape matches and values are all 42s (trainer sends 42s)
if received_shape == ipc_handle_dict["shape"]:
expected_sum = 42.0 * torch.tensor(ipc_handle_dict["shape"]).prod().item()
if abs(received_sum - expected_sum) < 0.01:
success = True
engine.shutdown()
return {
"success": success,
"received_shape": received_shape,
"received_sum": received_sum,
}
@pytest.mark.skipif(
torch.cuda.device_count() < 1,
reason="Need at least 1 GPU to run IPC weight transfer test.",
)
@pytest.mark.parametrize("mode", ["ray", "http"])
def test_ipc_weight_transfer_between_processes(mode: str):
"""Test IPC weight transfer from trainer to inference process using Ray.
Parametrized over transport modes:
- 'ray': ipc_handles passed directly.
- 'http': ipc_handles pickled + base64-encoded, unpickled via __post_init__.
IPC requires same-GPU access, so we use a placement group to co-locate
the trainer actor and inference task on the same GPU.
"""
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
ray.init(ignore_reinit_error=True)
# Create a placement group to ensure both processes are on the same GPU
# Use fractional GPUs so both tasks can share the same GPU bundle
pg = placement_group([{"GPU": 1, "CPU": 2}])
ray.get(pg.ready())
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_capture_child_tasks=True,
)
# Tensor to transfer: 100x100 filled with 42s
tensor_shape = [100, 100]
tensor_dtype = "float32"
# Create trainer actor that holds the tensor and IPC handle (stays alive)
trainer_actor = TrainerActor.options( # type: ignore[attr-defined]
scheduling_strategy=scheduling_strategy
).remote(tensor_shape, tensor_dtype)
# Get IPC handle dict (tensor stays alive in trainer actor)
ipc_handle_dict = ray.get(trainer_actor.get_ipc_handle_dict.remote())
# Receive tensor in inference process using IPC handles (on same GPU)
# Trainer actor stays alive during this operation
inference_result = ray.get(
inference_receive_ipc_tensor.options(
scheduling_strategy=scheduling_strategy
).remote(ipc_handle_dict, mode=mode)
)
assert inference_result["success"], (
f"IPC weight transfer failed (mode={mode}). "
f"Received shape: {inference_result['received_shape']}, "
f"Received sum: {inference_result['received_sum']}"
)
def test_ipc_receive_weights_missing_gpu_uuid_raises():
"""Test that receive_weights raises if GPU UUID not found in IPC handles."""
if torch.cuda.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
config = WeightTransferConfig(backend="ipc")
parallel_config = create_mock_parallel_config()
engine = IPCWeightTransferEngine(config, parallel_config)
# Create IPC handle with wrong GPU UUID
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
wrong_uuid = "wrong-uuid-12345"
ipc_handles = [{wrong_uuid: ipc_handle}]
update_info = IPCWeightTransferUpdateInfo(
names=["w"],
dtype_names=["float32"],
shapes=[[10, 10]],
ipc_handles=ipc_handles,
)
with pytest.raises(ValueError, match="IPC handle not found"):
engine.receive_weights(update_info, lambda x: None)

View File

@@ -37,6 +37,8 @@ CHECK_IMPORTS = {
"vllm/distributed/device_communicators/all_reduce_utils.py",
"vllm/distributed/device_communicators/shm_broadcast.py",
"vllm/distributed/device_communicators/shm_object_storage.py",
"vllm/distributed/weight_transfer/ipc_engine.py",
"tests/distributed/test_weight_transfer.py",
"vllm/utils/hashing.py",
"tests/multimodal/media/test_base.py",
"tests/tokenizers_/test_hf.py",

View File

@@ -9,5 +9,5 @@ from vllm.config.utils import config
class WeightTransferConfig:
"""Configuration for weight transfer during RL training."""
backend: Literal["nccl"] = "nccl"
backend: Literal["nccl", "ipc"] = "nccl"
"""The backend to use for weight transfer."""

View File

@@ -3,7 +3,7 @@
"""Base class for weight transfer engines."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from collections.abc import Callable, Iterator
from dataclasses import KW_ONLY, dataclass, field
from typing import Any, Generic, TypeVar
@@ -156,3 +156,30 @@ class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
This should be called when the worker is shutting down.
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def trainer_send_weights(
iterator: Iterator[tuple[str, torch.Tensor]],
trainer_args: dict[str, Any] | Any,
) -> None:
"""
Send weights from trainer to inference workers.
This is a static method that can be called from the trainer process
to send weights to all inference workers.
Args:
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
The tensors should be on the appropriate device for the backend.
trainer_args: Dictionary containing backend-specific arguments needed
to send weights. The structure depends on the backend:
- NCCL: Contains 'group', 'src', 'packed', etc.
- IPC: Contains 'mode' ('http' or 'ray'),
'llm_handle' (for Ray), 'url' (for HTTP), etc.
Example:
>>> param_iter = ((n, p) for n, p in model.named_parameters())
>>> engine.trainer_send_weights(param_iter, trainer_args)
"""
raise NotImplementedError

View File

@@ -114,3 +114,9 @@ WeightTransferEngineFactory.register_engine(
"vllm.distributed.weight_transfer.nccl_engine",
"NCCLWeightTransferEngine",
)
WeightTransferEngineFactory.register_engine(
"ipc",
"vllm.distributed.weight_transfer.ipc_engine",
"IPCWeightTransferEngine",
)

View File

@@ -0,0 +1,291 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""IPC-based weight transfer engine using CUDA IPC for communication."""
import base64
import pickle
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass
from typing import Any
import requests
import torch
from torch.multiprocessing.reductions import reduce_tensor
from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferEngine,
WeightTransferInitInfo,
WeightTransferUpdateInfo,
)
@dataclass
class IPCTrainerSendWeightsArgs:
"""Arguments for IPC trainer_send_weights method."""
mode: str
"""Transport mode: 'http' or 'ray'."""
llm_handle: Any = None
"""Ray ObjectRef to LLM handle (required for 'ray' mode)."""
url: str | None = None
"""Base URL for HTTP endpoint (required for 'http' mode)."""
def __post_init__(self):
"""Validate that required arguments are provided for the selected mode."""
if self.mode == "ray" and self.llm_handle is None:
raise ValueError("llm_handle is required for 'ray' mode")
if self.mode == "http" and self.url is None:
raise ValueError("url is required for 'http' mode")
if self.mode not in ("ray", "http"):
raise ValueError(f"mode must be 'ray' or 'http', got {self.mode}")
@dataclass
class IPCWeightTransferInitInfo(WeightTransferInitInfo):
"""Initialization info for IPC weight transfer backend. No init needed for IPC."""
pass
@dataclass
class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo):
"""Update info for IPC weight transfer backend.
Accepts IPC handles either directly via ``ipc_handles`` (Ray transport)
or as a base64-encoded pickle via ``ipc_handles_pickled`` (HTTP transport).
Exactly one of the two must be provided; if ``ipc_handles_pickled`` is set
it is unpickled into ``ipc_handles`` during ``__post_init__``.
"""
names: list[str]
dtype_names: list[str]
shapes: list[list[int]]
ipc_handles: list[dict[str, tuple[Callable, tuple]]] | None = None
"""IPC handles mapping physical GPU UUID to (func, args) tuple.
Each handle is a dictionary mapping GPU UUID strings to IPC handle tuples."""
ipc_handles_pickled: str | None = None
"""Base64-encoded pickled IPC handles, used for HTTP transport."""
def __post_init__(self):
if self.ipc_handles_pickled is not None:
if self.ipc_handles is not None:
raise ValueError(
"Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
)
self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled))
self.ipc_handles_pickled = None
if self.ipc_handles is None:
raise ValueError(
"Either `ipc_handles` or `ipc_handles_pickled` must be provided"
)
num_params = len(self.names)
if len(self.dtype_names) != num_params:
raise ValueError(
f"`dtype_names` should be of the same size as `names`: "
f"got {len(self.dtype_names)} and {len(self.names)}"
)
if len(self.shapes) != num_params:
raise ValueError(
f"`shapes` should be of the same size as `names`: "
f"got {len(self.shapes)} and {len(self.names)}"
)
if len(self.ipc_handles) != num_params:
raise ValueError(
f"`ipc_handles` should be of the same size as `names`: "
f"got {len(self.ipc_handles)} and {len(self.names)}"
)
class IPCWeightTransferEngine(
WeightTransferEngine[IPCWeightTransferInitInfo, IPCWeightTransferUpdateInfo]
):
"""
Weight transfer engine using CUDA IPC for communication between trainer and workers.
This implementation uses CUDA IPC to transfer weights from the trainer (rank 0)
to all inference workers in a process group. IPC handles are used to share
memory between processes on the same node.
"""
# Define backend-specific dataclass types
init_info_cls = IPCWeightTransferInitInfo
update_info_cls = IPCWeightTransferUpdateInfo
def __init__(
self, config: WeightTransferConfig, parallel_config: ParallelConfig
) -> None:
"""
Initialize the IPC weight transfer engine.
Args:
config: The configuration for the weight transfer engine
parallel_config: The configuration for the parallel setup
"""
super().__init__(config, parallel_config)
def init_transfer_engine(self, init_info: IPCWeightTransferInitInfo) -> None:
"""
Initialize the weight transfer mechanism.
This is called once at the beginning of training.
No initialization needed for IPC backend.
Args:
init_info: IPC initialization info (empty)
"""
pass
def receive_weights(
self,
update_info: IPCWeightTransferUpdateInfo,
load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
) -> None:
"""
Receive weights from the trainer via CUDA IPC handles.
Args:
update_info: IPC update info containing parameter names, dtypes, shapes,
and IPC handles. Each IPC handle is a mapping between physical
GPU UUID and the IPC handle tuple (func, args).
load_weights: Callable that loads weights into the model. Called
incrementally for each weight to avoid OOM.
"""
assert update_info.ipc_handles is not None
weights = []
for name, _dtype_name, _shape, ipc_handle in zip(
update_info.names,
update_info.dtype_names,
update_info.shapes,
update_info.ipc_handles,
):
device_index = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device_index)
physical_gpu_id = str(props.uuid)
if physical_gpu_id not in ipc_handle:
raise ValueError(
f"IPC handle not found for GPU UUID {physical_gpu_id}. "
f"Available UUIDs: {list(ipc_handle.keys())}"
)
handle = ipc_handle[physical_gpu_id]
func, args = handle
list_args = list(args) # type: ignore
# Index 6 is the device_index parameter in torch's
# IPC handle tuple (rebuild_cuda_tensor). Update it
# to the current device since the logical index can
# differ between sender and receiver.
list_args[6] = device_index
weight = func(*list_args) # type: ignore
weights.append((name, weight))
load_weights(weights)
def shutdown(self) -> None:
"""
Shutdown the weight transfer engine.
"""
pass
@staticmethod
def trainer_send_weights(
iterator: Iterator[tuple[str, torch.Tensor]],
trainer_args: dict[str, Any] | IPCTrainerSendWeightsArgs,
) -> None:
"""
Send weights from trainer to inference workers via CUDA IPC.
Supports two modes:
- 'ray': Sends weights via Ray RPC to a Ray-based LLM handle
- 'http': Sends weights via HTTP POST to a vLLM HTTP server
Args:
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
Tensors should be on the same GPU as the inference workers.
trainer_args: Dictionary containing IPC-specific arguments.
Should contain keys from IPCTrainerSendWeightsArgs:
- mode: 'ray' or 'http'
- llm_handle: Ray ObjectRef (for 'ray' mode)
- url: Base URL string (for 'http' mode)
Example (Ray mode):
>>> from vllm.distributed.weight_transfer.ipc_engine import (
... IPCWeightTransferEngine,
... IPCTrainerSendWeightsArgs,
... )
>>> param_iter = ((n, p) for n, p in model.named_parameters())
>>> args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
>>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
Example (HTTP mode):
>>> args = IPCTrainerSendWeightsArgs(
... mode="http", url="http://localhost:8000"
... )
>>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
"""
# Parse trainer args - accept either dict or dataclass instance
if isinstance(trainer_args, dict):
args = IPCTrainerSendWeightsArgs(**trainer_args)
else:
args = trainer_args
# Get physical GPU UUID
device_index = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device_index)
gpu_uuid = str(props.uuid)
# Collect weight metadata and create IPC handles
names = []
dtype_names = []
shapes = []
ipc_handles = []
for name, tensor in iterator:
names.append(name)
dtype_names.append(str(tensor.dtype).split(".")[-1])
shapes.append(list(tensor.shape))
# Create IPC handle for this weight tensor
# The tensor must remain in memory for IPC to work
weight = tensor.detach().contiguous()
ipc_handle = reduce_tensor(weight)
ipc_handles.append({gpu_uuid: ipc_handle})
# Send weights based on mode
if args.mode == "ray":
# Ray mode: send via Ray RPC
import ray
update_info = asdict(
IPCWeightTransferUpdateInfo(
names=names,
dtype_names=dtype_names,
shapes=shapes,
ipc_handles=ipc_handles,
)
)
ray.get(
args.llm_handle.update_weights.remote(dict(update_info=update_info))
)
elif args.mode == "http":
# HTTP mode: send via HTTP POST with pickled handles
# Pickle and base64 encode IPC handles for HTTP transmission
pickled_handles = base64.b64encode(pickle.dumps(ipc_handles)).decode(
"utf-8"
)
url = f"{args.url}/update_weights"
payload = {
"update_info": {
"names": names,
"dtype_names": dtype_names,
"shapes": shapes,
"ipc_handles_pickled": pickled_handles,
}
}
response = requests.post(url, json=payload, timeout=300)
response.raise_for_status()

View File

@@ -35,6 +35,32 @@ class NCCLWeightTransferInitInfo(WeightTransferInitInfo):
world_size: int
@dataclass
class NCCLTrainerSendWeightsArgs:
"""Arguments for NCCL trainer_send_weights method."""
group: Any
"""Process group (PyNcclCommunicator) for NCCL communication."""
src: int = 0
"""Source rank (default 0, trainer is typically rank 0)."""
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] | None = None
"""Optional function to apply to each (name, tensor) pair before broadcasting.
If None, extracts just the tensor."""
packed: bool = False
"""Whether to use packed tensor broadcasting for efficiency.
When True, multiple tensors are batched together before broadcasting
to reduce NCCL communication overhead."""
stream: torch.cuda.Stream | None = None
"""CUDA stream to use for broadcasting if packed is False.
If packed is True, new streams will be created for each buffer."""
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
"""Size in bytes for each packed tensor buffer.
Must match the value used in NCCLWeightTransferUpdateInfo."""
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
"""Number of buffers for double/triple buffering during packed transfer.
Must match the value used in NCCLWeightTransferUpdateInfo."""
@dataclass
class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
"""Update info for NCCL weight transfer backend."""
@@ -47,7 +73,7 @@ class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
When True, multiple tensors are batched together before broadcasting
to reduce NCCL communication overhead."""
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
"""Size in bytes for each packed tensor buffer. Default is 1GB.
"""Size in bytes for each packed tensor buffer.
Both producer and consumer must use the same value."""
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
"""Number of buffers for double/triple buffering during packed transfer.
@@ -186,47 +212,38 @@ class NCCLWeightTransferEngine(
@staticmethod
def trainer_send_weights(
iterator: Iterator[tuple[str, torch.Tensor]],
group: Any,
src: int = 0,
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor]
| None = None,
packed: bool = False,
stream: torch.cuda.Stream | None = None,
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
trainer_args: dict[str, Any] | NCCLTrainerSendWeightsArgs,
) -> None:
"""Broadcast weights from trainer to vLLM workers.
Args:
iterator: Iterator of model parameters. Returns (name, tensor) tuples
group: Process group (PyNcclCommunicator)
src: Source rank (default 0, trainer is typically rank 0)
post_iter_func: Optional function to apply to each (name, tensor) pair
before broadcasting. If None, extracts just the tensor.
packed: Whether to use packed tensor broadcasting for efficiency.
When True, multiple tensors are batched together before
broadcasting to reduce NCCL communication overhead.
stream: CUDA stream to use for broadcasting if packed is False.
If packed is True, new streams will be created for each buffer.
packed_buffer_size_bytes: Size in bytes for each packed tensor buffer.
Must match the value used in NCCLWeightTransferUpdateInfo.
packed_num_buffers: Number of buffers for double/triple buffering.
Must match the value used in NCCLWeightTransferUpdateInfo.
trainer_args: Dictionary or NCCLTrainerSendWeightsArgs instance containing
NCCL-specific arguments. If a dict, should contain keys from
NCCLTrainerSendWeightsArgs.
Example:
>>> from vllm.distributed.weight_transfer.nccl_engine import (
... NCCLWeightTransferEngine,
... NCCLTrainerSendWeightsArgs,
... )
>>> param_iter = ((n, p) for n, p in model.named_parameters())
>>> NCCLWeightTransferEngine.trainer_send_weights(
... param_iter, group, packed=True
... )
>>> args = NCCLTrainerSendWeightsArgs(group=group, packed=True)
>>> NCCLWeightTransferEngine.trainer_send_weights(param_iter, args)
"""
if post_iter_func is None:
# Parse trainer args - accept either dict or dataclass instance
if isinstance(trainer_args, dict):
args = NCCLTrainerSendWeightsArgs(**trainer_args)
else:
args = trainer_args
if args.post_iter_func is None:
# Default: extract just the tensor from (name, tensor) tuple
post_iter_func = lambda x: x[1]
else:
post_iter_func = args.post_iter_func
if packed:
if args.packed:
# Use packed tensor broadcasting for efficiency
from vllm.distributed.weight_transfer.packed_tensor import (
packed_broadcast_producer,
@@ -234,18 +251,20 @@ class NCCLWeightTransferEngine(
packed_broadcast_producer(
iterator=iterator,
group=group,
src=src,
group=args.group,
src=args.src,
post_iter_func=post_iter_func,
buffer_size_bytes=packed_buffer_size_bytes,
num_buffers=packed_num_buffers,
buffer_size_bytes=args.packed_buffer_size_bytes,
num_buffers=args.packed_num_buffers,
)
else:
# Use simple one-by-one broadcasting
for item in iterator:
tensor = post_iter_func(item)
group.broadcast(
tensor, src=src, stream=stream or torch.cuda.current_stream()
args.group.broadcast(
tensor,
src=args.src,
stream=args.stream or torch.cuda.current_stream(),
)
@staticmethod