Convert examples to ruff-format (#18400)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -2,21 +2,20 @@
|
||||
import torch
|
||||
|
||||
|
||||
def stateless_init_process_group(master_address, master_port, rank, world_size,
|
||||
device):
|
||||
def stateless_init_process_group(master_address, master_port, rank, world_size, device):
|
||||
"""
|
||||
vLLM provides `StatelessProcessGroup` to create a process group
|
||||
without considering the global process group in torch.distributed.
|
||||
It is recommended to create `StatelessProcessGroup`, and then initialize
|
||||
the data-plane communication (NCCL) between external (train processes)
|
||||
the data-plane communication (NCCL) between external (train processes)
|
||||
and vLLM workers.
|
||||
"""
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
pg = StatelessProcessGroup.create(host=master_address,
|
||||
port=master_port,
|
||||
rank=rank,
|
||||
world_size=world_size)
|
||||
|
||||
pg = StatelessProcessGroup.create(
|
||||
host=master_address, port=master_port, rank=rank, world_size=world_size
|
||||
)
|
||||
pynccl = PyNcclCommunicator(pg, device=device)
|
||||
return pynccl
|
||||
|
||||
@@ -31,9 +30,11 @@ class WorkerExtension:
|
||||
should pass the full qualified name as `worker_extension_cls` argument.
|
||||
"""
|
||||
|
||||
def init_weight_update_group(self, master_address, master_port,
|
||||
rank_offset, world_size):
|
||||
def init_weight_update_group(
|
||||
self, master_address, master_port, rank_offset, world_size
|
||||
):
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
|
||||
rank = get_world_group().rank + rank_offset
|
||||
self.model_update_group = stateless_init_process_group(
|
||||
master_address,
|
||||
@@ -45,9 +46,9 @@ class WorkerExtension:
|
||||
|
||||
def update_weight(self, name, dtype, shape):
|
||||
weight = torch.empty(shape, dtype=dtype, device="cuda")
|
||||
self.model_update_group.broadcast(weight,
|
||||
src=0,
|
||||
stream=torch.cuda.current_stream())
|
||||
self.model_update_group.broadcast(
|
||||
weight, src=0, stream=torch.cuda.current_stream()
|
||||
)
|
||||
|
||||
self.model_runner.model.load_weights(weights=[(name, weight)])
|
||||
|
||||
@@ -59,8 +60,7 @@ class WorkerExtension:
|
||||
"""
|
||||
weights_updated = True
|
||||
for name, p in self.model_runner.model.named_parameters():
|
||||
weights_updated = weights_updated and torch.allclose(
|
||||
p, torch.zeros_like(p))
|
||||
weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
|
||||
return weights_updated
|
||||
|
||||
|
||||
@@ -76,6 +76,7 @@ class ColocateWorkerExtension:
|
||||
|
||||
def report_device_id(self) -> str:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
self.device_uuid = current_platform.get_device_uuid(self.device.index)
|
||||
return self.device_uuid
|
||||
|
||||
@@ -100,6 +101,5 @@ class ColocateWorkerExtension:
|
||||
"""
|
||||
weights_updated = True
|
||||
for name, p in self.model_runner.model.named_parameters():
|
||||
weights_updated = weights_updated and torch.allclose(
|
||||
p, torch.zeros_like(p))
|
||||
weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
|
||||
return weights_updated
|
||||
|
||||
Reference in New Issue
Block a user