[RL] fast weight update with zmq + ipc handles (#24295)
Signed-off-by: huangweixiao <huangweixiao@msh.team> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
from typing import Callable, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
|
||||
def stateless_init_process_group(master_address, master_port, rank, world_size, device):
|
||||
@@ -66,6 +70,27 @@ class WorkerExtension:
|
||||
return weights_updated
|
||||
|
||||
|
||||
def rebuild_ipc(
    handle: tuple[Callable, tuple], device_id: Optional[int] = None
) -> torch.Tensor:
    """Rebuild a tensor from an IPC handle, optionally retargeting its device.

    Args:
        handle: A ``(func, args)`` pair; the tensor is produced by calling
            ``func(*args)``.
        device_id: When not ``None``, written over element 6 of ``args``
            before the call.

    Returns:
        Whatever ``func`` returns for the (possibly patched) arguments.
    """
    rebuild_fn, packed_args = handle
    call_args = [*packed_args]
    if device_id is not None:
        # The sender and this process may have different CUDA_VISIBLE_DEVICES,
        # so the device id baked into the handle has to be replaced by the id
        # of the current device.
        # NOTE(review): slot 6 is presumably the device argument of the
        # producer's rebuild function -- confirm against the producer.
        call_args[6] = device_id
    return rebuild_fn(*call_args)
|
||||
|
||||
|
||||
class FlattenedTensorMetadata(TypedDict):
    """Metadata describing one tensor packed into the shared IPC buffer."""

    # Name paired with the tensor when it is passed on for weight loading.
    name: str
    # Shape the tensor's slice of the buffer is viewed as.
    shape: torch.Size
    # Element type used to reinterpret the tensor's slice of the buffer.
    dtype: torch.dtype
    # Start offset of this tensor in the shared ipc_buffer tensor.
    offset: int
|
||||
|
||||
|
||||
class ColocateWorkerExtension:
|
||||
"""
|
||||
The class for vLLM's worker to inherit from, in the colocate setting.
|
||||
@@ -76,27 +101,62 @@ class ColocateWorkerExtension:
|
||||
should pass the full qualified name as `worker_extension_cls` argument.
|
||||
"""
|
||||
|
||||
def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
    """Receive weights through a ZMQ REP socket and load them into the model.

    The worker connects to the endpoint registered for its device and then
    answers every request with an empty message once it has been handled:

    * a ``tuple`` payload is an IPC handle; it is rebuilt (via
      ``rebuild_ipc``) into the shared ``uint8`` buffer;
    * a ``list`` payload holds ``FlattenedTensorMetadata`` entries; each
      entry is sliced out of the current buffer, reinterpreted with its
      dtype and shape, and the batch is passed to ``load_weights``;
    * ``None`` marks the end of the update; post-load processing runs and
      the loop exits.

    The socket is closed and the buffer reference dropped even when an
    error interrupts the exchange, so a failed update does not leak them.

    Args:
        zmq_handles: Maps the string returned by ``report_device_id()`` to
            the ZMQ endpoint this worker should connect to.
    """
    from vllm.model_executor.model_loader.utils import process_weights_after_loading

    assert self.device is not None
    # Create the ZMQ context lazily and reuse it across calls.
    if getattr(self, "_zmq_ctx", None) is None:
        self._zmq_ctx = zmq.Context()
    socket = self._zmq_ctx.socket(zmq.REP)
    buffer: Optional[torch.Tensor] = None
    try:
        socket.connect(zmq_handles[self.report_device_id()])
        while True:
            # NOTE: recv_pyobj unpickles the incoming message, so the peer
            # on this endpoint must be trusted.
            payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = (
                socket.recv_pyobj()
            )
            if payload is None:
                # means the update is done
                process_weights_after_loading(
                    self.model_runner.model, self.model_config, self.device
                )
                torch.cuda.synchronize()
                socket.send(b"")
                break
            if isinstance(payload, tuple):
                # an ipc handle that vLLM can use `func, args = handle`
                # and `func(*args)` to rebuild GPU tensor.
                buffer = rebuild_ipc(payload, self.device.index)
                assert buffer.dtype == torch.uint8
                socket.send(b"")
                continue
            assert isinstance(payload, list)
            assert buffer is not None
            weights = []
            for item in payload:
                shape = item["shape"]
                if isinstance(shape, (list, tuple)):
                    shape = torch.Size(shape)
                assert isinstance(shape, torch.Size)
                dtype, offset = item["dtype"], item["offset"]
                size = dtype.itemsize * shape.numel()
                # Append the view directly instead of binding it to a local
                # name, so this frame keeps no stray reference into
                # `buffer` once `weights` is deleted.
                weights.append(
                    (
                        item["name"],
                        buffer[offset : offset + size].view(dtype=dtype).view(shape),
                    )
                )
            self.model_runner.model.load_weights(weights=weights)
            del weights
            torch.cuda.synchronize()
            socket.send(b"")
    finally:
        # Runs on success and on failure: release the socket and our
        # reference to the shared buffer before asking the allocator to
        # free cached memory.
        socket.close()
        del buffer
        gc.collect()
        torch.cuda.empty_cache()
|
||||
|
||||
def report_device_id(self) -> str:
    """Look up this worker's device UUID, cache it on the instance, and return it.

    The UUID is obtained from ``current_platform.get_device_uuid`` for
    ``self.device.index`` and stored as ``self.device_uuid``.
    """
    from vllm.platforms import current_platform

    device_index = self.device.index
    self.device_uuid = current_platform.get_device_uuid(device_index)
    return self.device_uuid
|
||||
|
||||
def update_weights_from_ipc_handles(self, ipc_handles):
    """Rebuild per-tensor IPC handles for this device and load them as weights.

    Args:
        ipc_handles: Mapping indexed by ``self.device_uuid``; the selected
            entry maps each weight name to a ``(func, args)`` IPC handle.

    Note:
        ``self.device_uuid`` is read directly. In the visible code it is
        only assigned by ``report_device_id()``, so that method presumably
        has to be called first -- confirm with the caller.
    """
    device_id = self.device.index
    # Delegate to the shared helper, which swaps in the current device id in
    # case the two processes have different CUDA_VISIBLE_DEVICES. The helper
    # leaves the handle untouched when `device_id` is None.
    weights = [
        (name, rebuild_ipc(handle, device_id))
        for name, handle in ipc_handles[self.device_uuid].items()
    ]
    self.model_runner.model.load_weights(weights=weights)
    torch.cuda.synchronize()
|
||||
|
||||
def check_weights_changed(self):
|
||||
"""
|
||||
Check if the weights are updated to 0.
|
||||
|
||||
Reference in New Issue
Block a user