[RL] fast weight update with zmq + ipc handles (#24295)

Signed-off-by: huangweixiao <huangweixiao@msh.team>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Weixiao Huang
2025-09-09 16:57:46 +08:00
committed by GitHub
parent 1116590b16
commit 3d2a2de8f7
2 changed files with 152 additions and 33 deletions

View File

@@ -28,12 +28,15 @@ Learn more about Ray placement groups:
https://docs.ray.io/en/latest/placement-groups.html
"""
import gc
import os
import ray
import torch
import zmq
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.multiprocessing.reductions import reduce_tensor
from vllm import LLM
@@ -86,20 +89,72 @@ class RayTrainingActor:
from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(0)
self.zmq_context = zmq.Context()
self.zmq_address_counter = 0
self.zmq_handle = None
def report_device_id(self) -> str:
return self.device_uuid
def get_weight_ipc_handles(self):
from torch.multiprocessing.reductions import reduce_tensor
def get_zmq_handles(self) -> dict[str, str]:
suffix = f"{self.device_uuid}-{self.zmq_address_counter}"
self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock"
self.zmq_address_counter += 1
return {self.device_uuid: self.zmq_handle}
data = {}
for name, p in self.model.named_parameters():
# A training actor might hold only a subset of the weights and may
# need to gather weights from other actors. For demonstration
# purposes, each training actor owns the full weight set.
data[name] = reduce_tensor(p.detach())
return {self.device_uuid: data}
def update_weights(self):
# align size to avoid misaligned address
align_size = 256
def get_size(p: torch.Tensor) -> int:
return (p.nbytes + align_size - 1) // align_size * align_size
named_parameters: dict[str, torch.nn.Parameter] = dict(
self.model.named_parameters()
)
max_tensor_size = max(get_size(p) for p in named_parameters.values())
# use max_tensor_size * 2 as buffer size
buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0")
s = self.zmq_context.socket(zmq.REQ)
s.bind(self.zmq_handle)
handle = reduce_tensor(buffer)
offset = 0
buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
named_tensors: list[dict] = []
real_tensors: list[torch.Tensor] = []
for name, p in named_parameters.items():
size = get_size(p)
if offset + size > buffer.numel():
buckets.append((named_tensors, real_tensors))
named_tensors, real_tensors = [], []
offset = 0
# assume tensors are contiguous
named_tensors.append(
{"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
)
real_tensors.append(p)
offset += size
if named_tensors:
buckets.append((named_tensors, real_tensors))
s.send_pyobj(handle)
s.recv()
for named_tensors, real_tensors in buckets:
offset = 0
for p in real_tensors:
buffer[offset : offset + p.nbytes].data.copy_(
p.data.view(-1).view(dtype=torch.uint8), non_blocking=True
)
offset += get_size(p)
torch.cuda.synchronize()
s.send_pyobj(named_tensors)
s.recv()
s.send_pyobj(None)
s.recv()
s.close()
del buffer
gc.collect()
torch.cuda.empty_cache()
# Ray manages four GPUs.
@@ -175,18 +230,22 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
# the second inference engine.
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
print("Gather all the IPC handles from the training actors.")
ipc_handles = {}
print("Gather all the ZMQ handles from the training actors.")
zmq_handles = {}
for actor in training_actors:
ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
print(f"ZMQ handles: {zmq_handles}")
print("Update the weights of the inference engines.")
for llm in inference_engines:
ray.get(
llm.collective_rpc.remote(
"update_weights_from_ipc_handles", args=(ipc_handles,)
)
)
ray.get(
[actor.update_weights.remote() for actor in training_actors]
+ [
llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
for llm in inference_engines
]
)
print("Check if the weights are updated.")
for llm in inference_engines:
assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))