[Feat][RL][1/2] Native Weight Syncing API: NCCL (#31943)

Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Aaron Hao <ahao@anyscale.com>
Co-authored-by: SumanthRH <sumanthrh99@gmail.com>
This commit is contained in:
Aaron Hao
2026-02-05 09:13:23 -08:00
committed by GitHub
parent 82914d2ae8
commit c1858b7ec8
27 changed files with 2974 additions and 2 deletions

View File

@@ -33,6 +33,7 @@ from vllm.distributed.parallel_state import (
get_pp_group,
get_tp_group,
)
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.models.interfaces import is_mixture_of_experts
@@ -89,6 +90,16 @@ class Worker(WorkerBase):
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
# Weight transfer engine (initialized on-demand)
self.weight_transfer_engine = (
WeightTransferEngineFactory.create_engine(
self.vllm_config.weight_transfer_config,
self.vllm_config.parallel_config,
)
if self.vllm_config.weight_transfer_config is not None
else None
)
# Torch/CUDA profiler. Enabled and configured through profiler_config.
self.profiler: Any | None = None
profiler_config = vllm_config.profiler_config
@@ -932,6 +943,69 @@ class Worker(WorkerBase):
tensorizer_config=tensorizer_config,
)
def init_weight_transfer_engine(self, init_info: dict) -> None:
    """Set up the configured weight transfer mechanism.

    For the NCCL backend this establishes a process group shared with the
    trainer process.

    Args:
        init_info: Backend-specific initialization payload as a plain dict.

    Raises:
        RuntimeError: If no weight transfer engine was configured at
            construction time (i.e. ``weight_transfer_config`` was unset).
    """
    engine = self.weight_transfer_engine
    if engine is None:
        raise RuntimeError(
            "Weight transfer not configured. "
            "Please set weight_transfer_config to enable weight transfer."
        )
    # The engine converts the raw dict into its own typed dataclass
    # before the transfer channel is initialized.
    engine.init_transfer_engine(engine.parse_init_info(init_info))
def update_weights(self, update_info: dict) -> None:
    """Apply a batched weight update pushed from the trainer.

    Args:
        update_info: Backend-specific update payload as a plain dict.

    Raises:
        RuntimeError: If no weight transfer engine was configured at
            construction time (i.e. ``weight_transfer_config`` was unset).
    """
    engine = self.weight_transfer_engine
    if engine is None:
        raise RuntimeError(
            "Weight transfer not configured. "
            "Please set weight_transfer_config to enable weight transfer."
        )
    # The engine converts the raw dict into its own typed dataclass.
    typed_update_info = engine.parse_update_info(update_info)
    model = self.model_runner.model

    if not typed_update_info.is_checkpoint_format:
        # Weights already arrive in kernel format: copy each tensor
        # straight into the matching model parameter in place.
        def copy_into_params(
            weights: list[tuple[str, torch.Tensor]],
        ) -> None:
            for name, tensor in weights:
                model.get_parameter(name).copy_(tensor)

        engine.receive_weights(
            typed_update_info,
            load_weights=copy_into_params,
        )
        return

    # Checkpoint-format weights go through the layerwise reload path,
    # with tensor creation pinned to this worker's device.
    from vllm.model_executor.model_loader.reload import (
        finalize_layerwise_reload,
        initialize_layerwise_reload,
    )

    with torch.device(self.device):
        initialize_layerwise_reload(model)
        engine.receive_weights(
            typed_update_info,
            load_weights=model.load_weights,
        )
        finalize_layerwise_reload(model, self.model_config)
def shutdown(self) -> None:
# has_kv_transfer_group can be None during interpreter shutdown.
if ensure_kv_transfer_shutdown is not None:
@@ -939,6 +1013,9 @@ class Worker(WorkerBase):
if self.profiler is not None:
self.profiler.shutdown()
if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
weight_transfer_engine.shutdown()
def init_worker_distributed_environment(
vllm_config: VllmConfig,