[RLHF] use worker_extension_cls for compatibility with V0 and V1 (#14185)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
105
examples/offline_inference/rlhf_utils.py
Normal file
105
examples/offline_inference/rlhf_utils.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import torch
|
||||
|
||||
|
||||
def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):
    """
    Build a NCCL communicator backed by a vLLM `StatelessProcessGroup`.

    A `StatelessProcessGroup` is created without touching the global
    process group of torch.distributed. The recommended pattern is to
    create such a group first and then set up the data-plane (NCCL)
    communication on top of it, linking the external training processes
    with the vLLM workers.

    Args:
        master_address: host passed to `StatelessProcessGroup.create`.
        master_port: port passed to `StatelessProcessGroup.create`.
        rank: rank of the calling process inside the new group.
        world_size: total number of processes in the new group.
        device: device handed to `PyNcclCommunicator`.

    Returns:
        The `PyNcclCommunicator` built on the new process group.
    """
    # vLLM is imported lazily, inside the function, as in the original.
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    group_options = {
        "host": master_address,
        "port": master_port,
        "rank": rank,
        "world_size": world_size,
    }
    process_group = StatelessProcessGroup.create(**group_options)
    return PyNcclCommunicator(process_group, device=device)
|
||||
|
||||
|
||||
class WorkerExtension:
    """
    Extension class meant to be inherited by vLLM's worker.

    Putting the extra methods in an extension class keeps the code
    independent of the concrete worker class underneath, so the same
    code stays compatible with both vLLM V0 and V1.

    NOTE: this class lives in its own module; the main module should
    pass its fully qualified name as the `worker_extension_cls` argument.
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        """
        Create the communicator used to receive weight updates.

        The rank inside the new group is the rank reported by the world
        group plus `rank_offset`. The resulting communicator is stored
        on `self.model_update_group`.
        """
        from vllm.distributed.parallel_state import get_world_group

        group_rank = get_world_group().rank + rank_offset
        communicator = stateless_init_process_group(
            master_address,
            master_port,
            group_rank,
            world_size,
            self.device,
        )
        self.model_update_group = communicator

    def update_weight(self, name, dtype, shape):
        """
        Receive one tensor by broadcast from rank 0 and load it as `name`.

        A buffer of the given `shape` and `dtype` is allocated on "cuda",
        filled by the broadcast, and then passed to the model's
        `load_weights`.
        """
        buffer = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(
            buffer,
            src=0,
            stream=torch.cuda.current_stream(),
        )

        named_weights = [(name, buffer)]
        self.model_runner.model.load_weights(weights=named_weights)

        # Drop the local references to the receive buffer right away.
        del named_weights, buffer

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.

        Returns True when every parameter of the model is close to an
        all-zero tensor of the same shape (also when the model has no
        parameters), and False otherwise.
        """
        named_params = self.model_runner.model.named_parameters()
        return all(
            torch.allclose(param, torch.zeros_like(param))
            for _, param in named_params)
|
||||
|
||||
|
||||
class ColocateWorkerExtension:
    """
    Extension class meant to be inherited by vLLM's worker when the
    workers are colocated.

    Putting the extra methods in an extension class keeps the code
    independent of the concrete worker class underneath, so the same
    code stays compatible with both vLLM V0 and V1.

    NOTE: this class lives in its own module; the main module should
    pass its fully qualified name as the `worker_extension_cls` argument.
    """

    def report_device_id(self) -> str:
        """
        Look up the UUID of this worker's device, cache it on
        `self.device_uuid`, and return it.
        """
        from vllm.platforms import current_platform

        uuid = current_platform.get_device_uuid(self.device.index)
        self.device_uuid = uuid
        return uuid

    def update_weights_from_ipc_handles(self, ipc_handles):
        """
        Rebuild tensors from IPC handles and load them into the model.

        `ipc_handles` is indexed by device UUID; only the entry for
        `self.device_uuid` is used, so `report_device_id` must have set
        that attribute beforehand. Each handle is a `(func, args)` pair
        and the tensor is obtained by calling `func(*args)` after
        patching the device id.
        """
        local_handles = ipc_handles[self.device_uuid]
        local_device = self.device.index

        def rebuild(handle):
            rebuild_fn, rebuild_args = handle
            patched_args = list(rebuild_args)
            # the key is to change device id to the current device id
            # in case two processes have different CUDA_VISIBLE_DEVICES
            patched_args[6] = local_device
            return rebuild_fn(*patched_args)

        weights = [(name, rebuild(handle))
                   for name, handle in local_handles.items()]
        self.model_runner.model.load_weights(weights=weights)
        torch.cuda.synchronize()

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.

        Returns True when every parameter of the model is close to an
        all-zero tensor of the same shape (also when the model has no
        parameters), and False otherwise.
        """
        named_params = self.model_runner.model.named_parameters()
        return all(
            torch.allclose(param, torch.zeros_like(param))
            for _, param in named_params)
|
||||
Reference in New Issue
Block a user