[Feat][RL][1/2] Native Weight Syncing API: NCCL (#31943)
Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Aaron Hao <ahao@anyscale.com>
Co-authored-by: SumanthRH <sumanthrh99@gmail.com>
@@ -33,6 +33,7 @@ from vllm.distributed.parallel_state import (
     get_pp_group,
     get_tp_group,
 )
+from vllm.distributed.weight_transfer import WeightTransferEngineFactory
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.models.interfaces import is_mixture_of_experts
@@ -89,6 +90,16 @@ class Worker(WorkerBase):
         # Buffers saved before sleep
         self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

+        # Weight transfer engine (initialized on-demand)
+        self.weight_transfer_engine = (
+            WeightTransferEngineFactory.create_engine(
+                self.vllm_config.weight_transfer_config,
+                self.vllm_config.parallel_config,
+            )
+            if self.vllm_config.weight_transfer_config is not None
+            else None
+        )
+
         # Torch/CUDA profiler. Enabled and configured through profiler_config.
         self.profiler: Any | None = None
         profiler_config = vllm_config.profiler_config
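For reference, the engine returned by WeightTransferEngineFactory.create_engine has to expose the methods the worker calls in the hunks below (parse_init_info, init_transfer_engine, parse_update_info, receive_weights, shutdown). Here is a minimal sketch of that interface, reconstructed from those call sites; the real definitions live in vllm.distributed.weight_transfer, and the signatures and the dataclass field shown here are assumptions, not the actual API:

# Sketch of the engine interface implied by the worker-side call sites.
# Method names come from this diff; signatures are assumptions.
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass

import torch


@dataclass
class WeightUpdateInfo:
    # Hypothetical dataclass; the diff only shows the .is_checkpoint_format field.
    is_checkpoint_format: bool


class WeightTransferEngine(ABC):
    @abstractmethod
    def parse_init_info(self, init_info: dict):
        """Validate the raw dict into a backend-specific typed dataclass."""

    @abstractmethod
    def init_transfer_engine(self, typed_init_info) -> None:
        """For the NCCL backend, join a process group shared with the trainer."""

    @abstractmethod
    def parse_update_info(self, update_info: dict) -> WeightUpdateInfo:
        """Validate the raw dict describing one batched weight update."""

    @abstractmethod
    def receive_weights(
        self,
        typed_update_info: WeightUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """Receive tensors and hand (name, tensor) pairs to the loader callback."""

    @abstractmethod
    def shutdown(self) -> None:
        """Tear down communicators and free any staging buffers."""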
@@ -932,6 +943,69 @@ class Worker(WorkerBase):
             tensorizer_config=tensorizer_config,
         )

+    def init_weight_transfer_engine(self, init_info: dict) -> None:
+        """
+        Initialize weight transfer mechanism.
+        For NCCL backend, this creates a process group with the trainer.
+
+        Args:
+            init_info: Dictionary containing backend-specific initialization info
+        """
+        if self.weight_transfer_engine is None:
+            raise RuntimeError(
+                "Weight transfer not configured. "
+                "Please set weight_transfer_config to enable weight transfer."
+            )
+        # Parse dict into backend-specific typed dataclass
+        typed_init_info = self.weight_transfer_engine.parse_init_info(init_info)
+        self.weight_transfer_engine.init_transfer_engine(typed_init_info)
+
+    def update_weights(self, update_info: dict) -> None:
+        """
+        Batched weight update from the trainer.
+
+        Args:
+            update_info: Dictionary containing backend-specific update info
+        """
+        if self.weight_transfer_engine is None:
+            raise RuntimeError(
+                "Weight transfer not configured. "
+                "Please set weight_transfer_config to enable weight transfer."
+            )
+
+        # Parse dict into backend-specific typed dataclass
+        typed_update_info = self.weight_transfer_engine.parse_update_info(update_info)
+
+        model = self.model_runner.model
+
+        if typed_update_info.is_checkpoint_format:
+            from vllm.model_executor.model_loader.reload import (
+                finalize_layerwise_reload,
+                initialize_layerwise_reload,
+            )
+
+            # Use layerwise reload pattern for checkpoint format weights
+            with torch.device(self.device):
+                initialize_layerwise_reload(model)
+                self.weight_transfer_engine.receive_weights(
+                    typed_update_info,
+                    load_weights=model.load_weights,
+                )
+                finalize_layerwise_reload(model, self.model_config)
+        else:
+            # Weights are already in kernel format, copy directly
+            def load_weights_direct(
+                weights: list[tuple[str, torch.Tensor]],
+            ) -> None:
+                for name, weight in weights:
+                    param = model.get_parameter(name)
+                    param.copy_(weight)
+
+            self.weight_transfer_engine.receive_weights(
+                typed_update_info,
+                load_weights=load_weights_direct,
+            )
+
     def shutdown(self) -> None:
         # has_kv_transfer_group can be None during interpreter shutdown.
         if ensure_kv_transfer_shutdown is not None:
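Both new worker methods are meant to be driven from the trainer process. A hedged sketch of what that driver could look like through vLLM's collective_rpc; the init_info and update_info keys shown are hypothetical placeholders, since each backend defines its own schema via parse_init_info / parse_update_info:

# Trainer-side driver sketch. collective_rpc dispatches a method by name to
# every worker; the dict contents below are illustrative, not the PR's schema.
from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")  # weight_transfer_config assumed set via engine args

# One-time handshake: for the NCCL backend this creates a process group
# spanning the trainer and every vLLM worker.
llm.collective_rpc(
    "init_weight_transfer_engine",
    args=({"master_address": "10.0.0.1", "master_port": 29500},),  # hypothetical keys
)

# After each training step, push the new weights in one batched update.
llm.collective_rpc(
    "update_weights",
    args=({"is_checkpoint_format": False},),  # hypothetical schema
)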
@@ -939,6 +1013,9 @@ class Worker(WorkerBase):
         if self.profiler is not None:
             self.profiler.shutdown()

+        if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
+            weight_transfer_engine.shutdown()
+

 def init_worker_distributed_environment(
     vllm_config: VllmConfig,